Author: Alex Gaynor <[email protected]>
Branch: numpy-exp
Changeset: r44169:c53b9efed043
Date: 2011-05-14 16:08 -0500
http://bitbucket.org/pypy/pypy/changeset/c53b9efed043/

Log:    Convert the numpy interpreter to just be an AST walker; it's much
        less code this way.

diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,7 +3,6 @@
 from pypy.interpreter.gateway import interp2app, unwrap_spec
 from pypy.interpreter.typedef import TypeDef
 from pypy.rlib import jit
-from pypy.rlib.nonconst import NonConstant
 from pypy.rpython.lltypesystem import lltype
 from pypy.tool.sourcetools import func_with_new_name
 
@@ -18,171 +17,13 @@
 
 TP = lltype.Array(lltype.Float, hints={'nolength': True})
 
-numpy_driver = jit.JitDriver(greens = ['bytecode_pos', 'bytecode'],
-                             reds = ['result_size', 'i', 'frame',
-                                     'result'],
-                             virtualizables = ['frame'])
-
-class ComputationFrame(object):
-    _virtualizable2_ = ['valuestackdepth', 'valuestack[*]',
-                        'array_pos', 'arrays[*]',
-                        'float_pos', 'floats[*]',
-                        'function_pos', 'functions[*]',
-                        ]
-
-    def __init__(self, arrays, floats, functions):
-        self = jit.hint(self, access_directly=True, fresh_virtualizable=True)
-        self.valuestackdepth = 0
-        self.arrays = arrays
-        self.array_pos = len(arrays)
-        self.floats = floats
-        if NonConstant(0):
-            self.floats = [3.5] # annotator hack for test_zjit
-        self.float_pos = len(floats)
-        self.functions = functions
-        if NonConstant(0):
-            self.functions = [dummy1, dummy2] # another annotator hack
-        self.function_pos = len(functions)
-        self.valuestack = [0.0] * (len(arrays) + len(floats))
-
-    def reset(self):
-        self.valuestackdepth = 0
-        self.array_pos = len(self.arrays)
-        self.float_pos = len(self.floats)
-        self.function_pos = len(self.functions)
-
-    def _get_item(value):
-        pos_name = value + "_pos"
-        array_name = value + "s"
-        def impl(self):
-            p = getattr(self, pos_name) - 1
-            assert p >= 0
-            res = getattr(self, array_name)[p]
-            setattr(self, pos_name, p)
-            return res
-        return func_with_new_name(impl, "get" + value)
-
-    getarray = _get_item("array")
-    getfloat = _get_item("float")
-    getfunction = _get_item("function")
-
-    def popvalue(self):
-        v = self.valuestackdepth - 1
-        assert v >= 0
-        res = self.valuestack[v]
-        self.valuestackdepth = v
-        return res
-
-    def pushvalue(self, v):
-        self.valuestack[self.valuestackdepth] = v
-        self.valuestackdepth += 1
-
-class Code(object):
-    """
-    A chunk of bytecode.
-    """
-
-    def __init__(self, bytecode, arrays=None, floats=None, functions=None):
-        self.bytecode = bytecode
-        self.arrays = arrays or []
-        self.floats = floats or []
-        self.functions = functions or []
-
-    def merge(self, code, other):
-        """
-        Merge this bytecode with the other bytecode, using ``code`` as the
-        bytecode instruction for performing the merge.
-        """
-
-        return Code(code + self.bytecode + other.bytecode,
-            self.arrays + other.arrays,
-            self.floats + other.floats,
-            self.functions + other.functions)
-
-    def intern(self):
-        # the point of these hacks is to intern the bytecode string otherwise
-        # we have to compile new assembler each time, which sucks (we still
-        # have to compile new bytecode, but too bad)
-        try:
-            self.bytecode = JITCODES[self.bytecode]
-        except KeyError:
-            JITCODES[self.bytecode] = self.bytecode
-
-    def compute(self):
-        """
-        Crunch a ``Code`` full of bytecode.
-        """
-
-        bytecode = self.bytecode
-        result_size = self.arrays[0].size
-        result = SingleDimArray(result_size)
-        bytecode_pos = len(bytecode) - 1
-        i = 0
-        frame = ComputationFrame(self.arrays, self.floats, self.functions)
-        while i < result_size:
-            numpy_driver.jit_merge_point(bytecode=bytecode, result=result,
-                                         result_size=result_size,
-                                         i=i, frame=frame,
-                                         bytecode_pos=bytecode_pos)
-            if bytecode_pos == -1:
-                bytecode_pos = len(bytecode) - 1
-                frame.reset()
-                result.storage[i] = frame.valuestack[0]
-                i += 1
-                numpy_driver.can_enter_jit(bytecode=bytecode, result=result,
-                                           result_size=result_size,
-                                           i=i, frame=frame,
-                                           bytecode_pos=bytecode_pos)
-            else:
-                opcode = bytecode[bytecode_pos]
-                if opcode == 'l':
-                    # Load array.
-                    val = frame.getarray().storage[i]
-                    frame.pushvalue(val)
-                elif opcode == 'f':
-                    # Load float.
-                    val = frame.getfloat()
-                    frame.pushvalue(val)
-                elif opcode == 'a':
-                    # Add.
-                    a = frame.popvalue()
-                    b = frame.popvalue()
-                    frame.pushvalue(a + b)
-                elif opcode == 's':
-                    # Subtract
-                    a = frame.popvalue()
-                    b = frame.popvalue()
-                    frame.pushvalue(a - b)
-                elif opcode == 'm':
-                    # Multiply.
-                    a = frame.popvalue()
-                    b = frame.popvalue()
-                    frame.pushvalue(a * b)
-                elif opcode == 'd':
-                    a = frame.popvalue()
-                    b = frame.popvalue()
-                    frame.pushvalue(a / b)
-                elif opcode == 'c':
-                    func = frame.getfunction()
-                    val = frame.popvalue()
-                    frame.pushvalue(func(val))
-                else:
-                    raise NotImplementedError(
-                        "Can't handle bytecode instruction %s" % opcode)
-                bytecode_pos -= 1
-        return result
-
-JITCODES = {}
+numpy_driver = jit.JitDriver(greens = ['bytecode'],
+                             reds = ['result_size', 'i', 'self', 'result'])
 
 class BaseArray(Wrappable):
     def __init__(self):
         self.invalidates = []
 
-    def force(self):
-        code = self.compile()
-        code.intern()
-        return code.compute()
-
     def invalidated(self):
         for arr in self.invalidates:
             arr.force_if_needed()
@@ -208,9 +49,6 @@
     descr_mul = _binop_impl("m")
     descr_div = _binop_impl("d")
 
-    def compile(self):
-        raise NotImplementedError("abstract base class")
-
     def get_concrete(self):
         raise NotImplementedError
 
@@ -236,8 +74,14 @@
         BaseArray.__init__(self)
         self.float_value = float_value
 
-    def compile(self):
-        return Code('f', floats=[self.float_value])
+    def bytecode(self):
+        return "f"
+
+    def find_size(self):
+        raise ValueError
+
+    def eval(self, i):
+        return self.float_value
 
 class VirtualArray(BaseArray):
     """
@@ -247,19 +91,32 @@
         BaseArray.__init__(self)
         self.forced_result = None
 
-    def compile(self):
-        if self.forced_result is not None:
-            return self.forced_result.compile()
-        return self._compile()
+    def compute(self):
+        i = 0
+        bytecode = self.bytecode()
+        result_size = self.find_size()
+        result = SingleDimArray(result_size)
+        while i < result_size:
+            numpy_driver.jit_merge_point(bytecode=bytecode,
+                                         result_size=result_size, i=i,
+                                         self=self, result=result)
+            result.storage[i] = self.eval(i)
+            i += 1
+        return result
 
     def force_if_needed(self):
         if self.forced_result is None:
-            self.forced_result = self.force()
+            self.forced_result = self.compute()
 
     def get_concrete(self):
         self.force_if_needed()
         return self.forced_result
 
+    def eval(self, i):
+        if self.forced_result is not None:
+            return self.forced_result.eval(i)
+        return self._eval(i)
+
 
 class BinOp(VirtualArray):
     """
@@ -272,10 +129,28 @@
         self.left = left
         self.right = right
 
-    def _compile(self):
-        left_code = self.left.compile()
-        right_code = self.right.compile()
-        return left_code.merge(self.opcode, right_code)
+    def bytecode(self):
+        return self.opcode + self.left.bytecode() + self.right.bytecode()
+
+    def find_size(self):
+        try:
+            return self.left.find_size()
+        except ValueError:
+            pass
+        return self.right.find_size()
+
+    def _eval(self, i):
+        lhs, rhs = self.left.eval(i), self.right.eval(i)
+        if self.opcode == "a":
+            return lhs + rhs
+        elif self.opcode == "s":
+            return lhs - rhs
+        elif self.opcode == "m":
+            return lhs * rhs
+        elif self.opcode == "d":
+            return lhs / rhs
+        else:
+            raise NotImplementedError("Don't know opcode %s" % self.opcode)
 
 class Call(VirtualArray):
     def __init__(self, function, values):
@@ -283,8 +158,14 @@
         self.function = function
         self.values = values
 
-    def _compile(self):
-        return Code('', functions=[self.function]).merge('c', 
self.values.compile())
+    def bytecode(self):
+        return "c" + self.values.bytecode()
+
+    def find_size(self):
+        return self.values.find_size()
+
+    def _eval(self, i):
+        return self.function(self.values.eval(i))
 
 
 class SingleDimArray(BaseArray):
@@ -295,12 +176,18 @@
                                      flavor='raw', track_allocation=False)
         # XXX find out why test_zjit explodes with trackign of allocations
 
-    def compile(self):
-        return Code('l', arrays=[self])
-
     def get_concrete(self):
         return self
 
+    def bytecode(self):
+        return "l"
+
+    def find_size(self):
+        return self.size
+
+    def eval(self, i):
+        return self.storage[i]
+
     def getindex(self, space, item):
         if item >= self.size:
             raise operationerrfmt(space.w_IndexError,
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -2,6 +2,7 @@
 from pypy.module.micronumpy.interp_numarray import (SingleDimArray, BinOp,
     FloatWrapper, Call)
 from pypy.module.micronumpy.interp_ufuncs import negative_impl
+from pypy.rlib.nonconst import NonConstant
 
 
 class FakeSpace(object):
@@ -17,10 +18,10 @@
         def f(i):
             ar = SingleDimArray(i)
             if i:
-                v = BinOp('a', ar, ar)
+                v = BinOp(NonConstant('a'), ar, ar)
             else:
                 v = ar
-            return v.force().storage[3]
+            return v.get_concrete().storage[3]
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
         self.check_loops({'getarrayitem_raw': 2, 'float_add': 1,
@@ -34,10 +35,10 @@
         def f(i):
             ar = SingleDimArray(i)
             if i:
-                v = BinOp('a', ar, FloatWrapper(4.5))
+                v = BinOp(NonConstant('a'), ar, FloatWrapper(4.5))
             else:
                 v = ar
-            return v.force().storage[3]
+            return v.get_concrete().storage[3]
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
         self.check_loops({"getarrayitem_raw": 1, "float_add": 1,
@@ -53,7 +54,7 @@
             v1 = BinOp('a', ar, FloatWrapper(4.5))
             v2 = BinOp('m', v1, FloatWrapper(4.5))
             v1.force_if_needed()
-            return v2.force().storage[3]
+            return v2.get_concrete().storage[3]
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
         # This is the sum of the ops for both loops, however if you remove the
@@ -68,9 +69,9 @@
         space = self.space
         def f(i):
             ar = SingleDimArray(i)
-            v1 = BinOp('a', ar, ar)
+            v1 = BinOp(NonConstant('a'), ar, ar)
             v2 = Call(negative_impl, v1)
-            return v2.force().storage[3]
+            return v2.get_concrete().storage[3]
 
         result = self.meta_interp(f, [5], listops=True, backendopt=True)
         self.check_loops({"getarrayitem_raw": 2, "float_add": 1, "float_neg": 
1,
diff --git a/pypy/rlib/nonconst.py b/pypy/rlib/nonconst.py
--- a/pypy/rlib/nonconst.py
+++ b/pypy/rlib/nonconst.py
@@ -18,6 +18,12 @@
     def __nonzero__(self):
         return bool(self.__dict__['constant'])
 
+    def __eq__(self, other):
+        return self.__dict__['constant'] == other
+
+    def __add__(self, other):
+        return self.__dict__['constant'] + other
+
 class EntryNonConstant(ExtRegistryEntry):
     _about_ = NonConstant
 
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to