Author: Alex Gaynor <[email protected]>
Branch: numpy-exp
Changeset: r44169:c53b9efed043
Date: 2011-05-14 16:08 -0500
http://bitbucket.org/pypy/pypy/changeset/c53b9efed043/
Log: Convert the numpy interpreter into a plain AST walker; it's much
less code this way.
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,7 +3,6 @@
from pypy.interpreter.gateway import interp2app, unwrap_spec
from pypy.interpreter.typedef import TypeDef
from pypy.rlib import jit
-from pypy.rlib.nonconst import NonConstant
from pypy.rpython.lltypesystem import lltype
from pypy.tool.sourcetools import func_with_new_name
@@ -18,171 +17,13 @@
TP = lltype.Array(lltype.Float, hints={'nolength': True})
-numpy_driver = jit.JitDriver(greens = ['bytecode_pos', 'bytecode'],
- reds = ['result_size', 'i', 'frame',
- 'result'],
- virtualizables = ['frame'])
-
-class ComputationFrame(object):
- _virtualizable2_ = ['valuestackdepth', 'valuestack[*]',
- 'array_pos', 'arrays[*]',
- 'float_pos', 'floats[*]',
- 'function_pos', 'functions[*]',
- ]
-
- def __init__(self, arrays, floats, functions):
- self = jit.hint(self, access_directly=True, fresh_virtualizable=True)
- self.valuestackdepth = 0
- self.arrays = arrays
- self.array_pos = len(arrays)
- self.floats = floats
- if NonConstant(0):
- self.floats = [3.5] # annotator hack for test_zjit
- self.float_pos = len(floats)
- self.functions = functions
- if NonConstant(0):
- self.functions = [dummy1, dummy2] # another annotator hack
- self.function_pos = len(functions)
- self.valuestack = [0.0] * (len(arrays) + len(floats))
-
- def reset(self):
- self.valuestackdepth = 0
- self.array_pos = len(self.arrays)
- self.float_pos = len(self.floats)
- self.function_pos = len(self.functions)
-
- def _get_item(value):
- pos_name = value + "_pos"
- array_name = value + "s"
- def impl(self):
- p = getattr(self, pos_name) - 1
- assert p >= 0
- res = getattr(self, array_name)[p]
- setattr(self, pos_name, p)
- return res
- return func_with_new_name(impl, "get" + value)
-
- getarray = _get_item("array")
- getfloat = _get_item("float")
- getfunction = _get_item("function")
-
- def popvalue(self):
- v = self.valuestackdepth - 1
- assert v >= 0
- res = self.valuestack[v]
- self.valuestackdepth = v
- return res
-
- def pushvalue(self, v):
- self.valuestack[self.valuestackdepth] = v
- self.valuestackdepth += 1
-
-class Code(object):
- """
- A chunk of bytecode.
- """
-
- def __init__(self, bytecode, arrays=None, floats=None, functions=None):
- self.bytecode = bytecode
- self.arrays = arrays or []
- self.floats = floats or []
- self.functions = functions or []
-
- def merge(self, code, other):
- """
- Merge this bytecode with the other bytecode, using ``code`` as the
- bytecode instruction for performing the merge.
- """
-
- return Code(code + self.bytecode + other.bytecode,
- self.arrays + other.arrays,
- self.floats + other.floats,
- self.functions + other.functions)
-
- def intern(self):
- # the point of these hacks is to intern the bytecode string otherwise
- # we have to compile new assembler each time, which sucks (we still
- # have to compile new bytecode, but too bad)
- try:
- self.bytecode = JITCODES[self.bytecode]
- except KeyError:
- JITCODES[self.bytecode] = self.bytecode
-
- def compute(self):
- """
- Crunch a ``Code`` full of bytecode.
- """
-
- bytecode = self.bytecode
- result_size = self.arrays[0].size
- result = SingleDimArray(result_size)
- bytecode_pos = len(bytecode) - 1
- i = 0
- frame = ComputationFrame(self.arrays, self.floats, self.functions)
- while i < result_size:
- numpy_driver.jit_merge_point(bytecode=bytecode, result=result,
- result_size=result_size,
- i=i, frame=frame,
- bytecode_pos=bytecode_pos)
- if bytecode_pos == -1:
- bytecode_pos = len(bytecode) - 1
- frame.reset()
- result.storage[i] = frame.valuestack[0]
- i += 1
- numpy_driver.can_enter_jit(bytecode=bytecode, result=result,
- result_size=result_size,
- i=i, frame=frame,
- bytecode_pos=bytecode_pos)
- else:
- opcode = bytecode[bytecode_pos]
- if opcode == 'l':
- # Load array.
- val = frame.getarray().storage[i]
- frame.pushvalue(val)
- elif opcode == 'f':
- # Load float.
- val = frame.getfloat()
- frame.pushvalue(val)
- elif opcode == 'a':
- # Add.
- a = frame.popvalue()
- b = frame.popvalue()
- frame.pushvalue(a + b)
- elif opcode == 's':
- # Subtract
- a = frame.popvalue()
- b = frame.popvalue()
- frame.pushvalue(a - b)
- elif opcode == 'm':
- # Multiply.
- a = frame.popvalue()
- b = frame.popvalue()
- frame.pushvalue(a * b)
- elif opcode == 'd':
- a = frame.popvalue()
- b = frame.popvalue()
- frame.pushvalue(a / b)
- elif opcode == 'c':
- func = frame.getfunction()
- val = frame.popvalue()
- frame.pushvalue(func(val))
- else:
- raise NotImplementedError(
- "Can't handle bytecode instruction %s" % opcode)
- bytecode_pos -= 1
- return result
-
-JITCODES = {}
+numpy_driver = jit.JitDriver(greens = ['bytecode'],
+ reds = ['result_size', 'i', 'self', 'result'])
class BaseArray(Wrappable):
def __init__(self):
self.invalidates = []
- def force(self):
- code = self.compile()
- code.intern()
- return code.compute()
-
def invalidated(self):
for arr in self.invalidates:
arr.force_if_needed()
@@ -208,9 +49,6 @@
descr_mul = _binop_impl("m")
descr_div = _binop_impl("d")
- def compile(self):
- raise NotImplementedError("abstract base class")
-
def get_concrete(self):
raise NotImplementedError
@@ -236,8 +74,14 @@
BaseArray.__init__(self)
self.float_value = float_value
- def compile(self):
- return Code('f', floats=[self.float_value])
+ def bytecode(self):
+ return "f"
+
+ def find_size(self):
+ raise ValueError
+
+ def eval(self, i):
+ return self.float_value
class VirtualArray(BaseArray):
"""
@@ -247,19 +91,32 @@
BaseArray.__init__(self)
self.forced_result = None
- def compile(self):
- if self.forced_result is not None:
- return self.forced_result.compile()
- return self._compile()
+ def compute(self):
+ i = 0
+ bytecode = self.bytecode()
+ result_size = self.find_size()
+ result = SingleDimArray(result_size)
+ while i < result_size:
+ numpy_driver.jit_merge_point(bytecode=bytecode,
+ result_size=result_size, i=i,
+ self=self, result=result)
+ result.storage[i] = self.eval(i)
+ i += 1
+ return result
def force_if_needed(self):
if self.forced_result is None:
- self.forced_result = self.force()
+ self.forced_result = self.compute()
def get_concrete(self):
self.force_if_needed()
return self.forced_result
+ def eval(self, i):
+ if self.forced_result is not None:
+ return self.forced_result.eval(i)
+ return self._eval(i)
+
class BinOp(VirtualArray):
"""
@@ -272,10 +129,28 @@
self.left = left
self.right = right
- def _compile(self):
- left_code = self.left.compile()
- right_code = self.right.compile()
- return left_code.merge(self.opcode, right_code)
+ def bytecode(self):
+ return self.opcode + self.left.bytecode() + self.right.bytecode()
+
+ def find_size(self):
+ try:
+ return self.left.find_size()
+ except ValueError:
+ pass
+ return self.right.find_size()
+
+ def _eval(self, i):
+ lhs, rhs = self.left.eval(i), self.right.eval(i)
+ if self.opcode == "a":
+ return lhs + rhs
+ elif self.opcode == "s":
+ return lhs - rhs
+ elif self.opcode == "m":
+ return lhs * rhs
+ elif self.opcode == "d":
+ return lhs / rhs
+ else:
+ raise NotImplementedError("Don't know opcode %s" % self.opcode)
class Call(VirtualArray):
def __init__(self, function, values):
@@ -283,8 +158,14 @@
self.function = function
self.values = values
- def _compile(self):
- return Code('', functions=[self.function]).merge('c',
self.values.compile())
+ def bytecode(self):
+ return "c" + self.values.bytecode()
+
+ def find_size(self):
+ return self.values.find_size()
+
+ def _eval(self, i):
+ return self.function(self.values.eval(i))
class SingleDimArray(BaseArray):
@@ -295,12 +176,18 @@
flavor='raw', track_allocation=False)
# XXX find out why test_zjit explodes with trackign of allocations
- def compile(self):
- return Code('l', arrays=[self])
-
def get_concrete(self):
return self
+ def bytecode(self):
+ return "l"
+
+ def find_size(self):
+ return self.size
+
+ def eval(self, i):
+ return self.storage[i]
+
def getindex(self, space, item):
if item >= self.size:
raise operationerrfmt(space.w_IndexError,
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -2,6 +2,7 @@
from pypy.module.micronumpy.interp_numarray import (SingleDimArray, BinOp,
FloatWrapper, Call)
from pypy.module.micronumpy.interp_ufuncs import negative_impl
+from pypy.rlib.nonconst import NonConstant
class FakeSpace(object):
@@ -17,10 +18,10 @@
def f(i):
ar = SingleDimArray(i)
if i:
- v = BinOp('a', ar, ar)
+ v = BinOp(NonConstant('a'), ar, ar)
else:
v = ar
- return v.force().storage[3]
+ return v.get_concrete().storage[3]
result = self.meta_interp(f, [5], listops=True, backendopt=True)
self.check_loops({'getarrayitem_raw': 2, 'float_add': 1,
@@ -34,10 +35,10 @@
def f(i):
ar = SingleDimArray(i)
if i:
- v = BinOp('a', ar, FloatWrapper(4.5))
+ v = BinOp(NonConstant('a'), ar, FloatWrapper(4.5))
else:
v = ar
- return v.force().storage[3]
+ return v.get_concrete().storage[3]
result = self.meta_interp(f, [5], listops=True, backendopt=True)
self.check_loops({"getarrayitem_raw": 1, "float_add": 1,
@@ -53,7 +54,7 @@
v1 = BinOp('a', ar, FloatWrapper(4.5))
v2 = BinOp('m', v1, FloatWrapper(4.5))
v1.force_if_needed()
- return v2.force().storage[3]
+ return v2.get_concrete().storage[3]
result = self.meta_interp(f, [5], listops=True, backendopt=True)
# This is the sum of the ops for both loops, however if you remove the
@@ -68,9 +69,9 @@
space = self.space
def f(i):
ar = SingleDimArray(i)
- v1 = BinOp('a', ar, ar)
+ v1 = BinOp(NonConstant('a'), ar, ar)
v2 = Call(negative_impl, v1)
- return v2.force().storage[3]
+ return v2.get_concrete().storage[3]
result = self.meta_interp(f, [5], listops=True, backendopt=True)
self.check_loops({"getarrayitem_raw": 2, "float_add": 1, "float_neg":
1,
diff --git a/pypy/rlib/nonconst.py b/pypy/rlib/nonconst.py
--- a/pypy/rlib/nonconst.py
+++ b/pypy/rlib/nonconst.py
@@ -18,6 +18,12 @@
def __nonzero__(self):
return bool(self.__dict__['constant'])
+ def __eq__(self, other):
+ return self.__dict__['constant'] == other
+
+ def __add__(self, other):
+ return self.__dict__['constant'] + other
+
class EntryNonConstant(ExtRegistryEntry):
_about_ = NonConstant
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit