Author: mattip
Branch: matrixmath-dot
Changeset: r51432:f62709780578
Date: 2012-01-18 02:22 +0200
http://bitbucket.org/pypy/pypy/changeset/f62709780578/
Log: passes a test, needs cleanup
diff --git a/pypy/module/micronumpy/interp_iter.py
b/pypy/module/micronumpy/interp_iter.py
--- a/pypy/module/micronumpy/interp_iter.py
+++ b/pypy/module/micronumpy/interp_iter.py
@@ -2,7 +2,7 @@
from pypy.rlib import jit
from pypy.rlib.objectmodel import instantiate
from pypy.module.micronumpy.strides import calculate_broadcast_strides,\
- calculate_slice_strides
+ calculate_slice_strides, calculate_dot_strides
class BaseTransform(object):
    """Common base for the transform objects handed to iterators' transform().

    DotTransform (below) derives from it; it carries no behavior itself.
    """
@@ -16,6 +16,11 @@
def __init__(self, res_shape):
    """Store ``res_shape``, the target shape used when this transform applies."""
    self.res_shape = res_shape
class DotTransform(BaseTransform):
    """Transform that maps an operand onto the broadcast shape of a dot.

    ``res_shape`` is the shape the transformed iterator walks, and
    ``skip_dims`` lists the dimensions of ``res_shape`` along which the
    operand does not advance: calculate_dot_strides gives them stride 0.
    """
    def __init__(self, res_shape, skip_dims):
        self.skip_dims = skip_dims
        self.res_shape = res_shape
+
class BaseIterator(object):
def next(self, shapelen):
    """Abstract step method; concrete iterator classes must override it."""
    raise NotImplementedError
@@ -85,6 +90,10 @@
self.strides,
self.backstrides, t.chunks)
return ViewIterator(r[1], r[2], r[3], r[0])
+ elif isinstance(t, DotTransform):
+ r = calculate_dot_strides(self.strides, self.backstrides,
+ t.res_shape, t.skip_dims)
+ return ViewIterator(self.offset, r[0], r[1], t.res_shape)
@jit.unroll_safe
def next(self, shapelen):
@@ -130,6 +139,7 @@
def transform(self, arr, t):
    """No-op: ignore ``arr`` and the transform ``t`` and return None."""
+
class AxisIterator(BaseIterator):
def __init__(self, start, dim, shape, strides, backstrides):
self.res_shape = shape[:]
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -3,13 +3,14 @@
from pypy.interpreter.gateway import interp2app, NoneNotWrapped
from pypy.interpreter.typedef import TypeDef, GetSetProperty
from pypy.module.micronumpy import interp_ufuncs, interp_dtype, signature
-from pypy.module.micronumpy.strides import calculate_slice_strides
+from pypy.module.micronumpy.strides import calculate_slice_strides,\
+ calculate_dot_strides
from pypy.rlib import jit
from pypy.rpython.lltypesystem import lltype, rffi
from pypy.tool.sourcetools import func_with_new_name
from pypy.rlib.rstring import StringBuilder
from pypy.module.micronumpy.interp_iter import ArrayIterator, OneDimIterator,\
- SkipLastAxisIterator
+ SkipLastAxisIterator, ViewIterator
numpy_driver = jit.JitDriver(
greens=['shapelen', 'sig'],
@@ -211,6 +212,28 @@
n_old_elems_to_use *= old_shape[oldI]
return new_strides
def match_dot_shapes(space, self, other):
    """Compute the result shape of ``dot(self, other)`` and check alignment.

    The last axis of ``self`` is summed against the "critical" axis of
    ``other``. That axis is axis 0 when ``other`` has at most one
    dimension, otherwise its second-to-last axis.

    Returns ``(out_shape, other_critical_dim)``. ``out_shape`` is a new
    list: ``self.shape[:-1]`` followed by ``other.shape`` with the critical
    axis removed.

    Raises an app-level ValueError ("objects are not aligned") when the
    length of the last axis of ``self`` differs from the length of the
    critical axis of ``other``.
    """
    other_ndims = len(other.shape)
    if other_ndims > 1:
        # For N-d right operands the sum runs over the second-to-last axis.
        other_critical_dim = other_ndims - 2
    else:
        other_critical_dim = 0
    # Lets RPython prove the slice bounds below are non-negative.
    assert other_critical_dim >= 0
    if self.shape[-1] != other.shape[other_critical_dim]:
        raise OperationError(space.w_ValueError, space.wrap(
            "objects are not aligned"))
    out_shape = self.shape[:-1] + \
                other.shape[:other_critical_dim] + \
                other.shape[other_critical_dim + 1:]
    return out_shape, other_critical_dim
+
class BaseArray(Wrappable):
_attrs_ = ["invalidates", "shape", 'size']
@@ -384,70 +407,62 @@
the second-to-last of `b`::
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])'''
- w_other = convert_to_array(space, w_other)
- if isinstance(w_other, Scalar):
- return self.descr_mul(space, w_other)
- elif len(self.shape) < 2 and len(w_other.shape) < 2:
- w_res = self.descr_mul(space, w_other)
+ other = convert_to_array(space, w_other)
+ if isinstance(other, Scalar):
+ return self.descr_mul(space, other)
+ elif len(self.shape) < 2 and len(other.shape) < 2:
+ w_res = self.descr_mul(space, other)
assert isinstance(w_res, BaseArray)
return w_res.descr_sum(space, space.wrap(-1))
dtype = interp_ufuncs.find_binop_result_dtype(space,
- self.find_dtype(), w_other.find_dtype())
- if self.size < 1 and w_other.size < 1:
+ self.find_dtype(), other.find_dtype())
+ if self.size < 1 and other.size < 1:
#numpy compatability
return scalar_w(space, dtype, space.wrap(0))
#Do the dims match?
- my_critical_dim_size = self.shape[-1]
- other_critical_dim_size = w_other.shape[0]
- other_critical_dim = 0
- other_critical_dim_stride = w_other.strides[0]
- out_shape = []
- if len(w_other.shape) > 1:
- other_critical_dim = len(w_other.shape) - 2
- other_critical_dim_size = w_other.shape[other_critical_dim]
- other_critical_dim_stride = w_other.strides[other_critical_dim]
- assert other_critical_dim >= 0
- out_shape += self.shape[:-1] + \
- w_other.shape[0:other_critical_dim] + \
- w_other.shape[other_critical_dim + 1:]
- elif len(w_other.shape) > 0:
- #dot does not reduce for scalars
- out_shape += self.shape[:-1]
- if my_critical_dim_size != other_critical_dim_size:
- raise OperationError(space.w_ValueError, space.wrap(
- "objects are not aligned"))
+ out_shape, other_critical_dim = match_dot_shapes(space, self, other)
out_size = 1
- for os in out_shape:
- out_size *= os
- out_ndims = len(out_shape)
- # TODO: what should the order be? C or F?
- arr = W_NDimArray(out_size, out_shape, dtype=dtype)
- # TODO: this is all a bogus mess of previous work,
- # rework within the context of transformations
- '''
- out_iter = ViewIterator(arr.start, arr.strides, arr.backstrides,
arr.shape)
- # TODO: invalidate self, w_other with arr ?
- while not out_iter.done():
- my_index = self.start
- other_index = w_other.start
- i = 0
- while i < len(self.shape) - 1:
- my_index += out_iter.indices[i] * self.strides[i]
- i += 1
- for j in range(len(w_other.shape) - 2):
- other_index += out_iter.indices[i] * w_other.strides[j]
- other_index += out_iter.indices[-1] * w_other.strides[-1]
- w_ssd = space.newlist([space.wrap(my_index),
- space.wrap(len(self.shape) - 1)])
- w_osd = space.newlist([space.wrap(other_index),
- space.wrap(other_critical_dim)])
- w_res = self.descr_mul(space, w_other)
- assert isinstance(w_res, BaseArray)
- value = w_res.descr_sum(space)
- arr.setitem(out_iter.get_offset(), value)
- out_iter = out_iter.next(out_ndims)
- '''
- return arr
+ for o in out_shape:
+ out_size *= o
+ result = W_NDimArray(out_size, out_shape, dtype)
+ # given a.shape == [3, 5, 7],
+ # b.shape == [2, 7, 4]
+ # result.shape == [3, 5, 2, 4]
+ # all iterators shapes should be [3, 5, 2, 7, 4]
+ # result should skip dims 3 which is results.ndims - 1
+ # a should skip 2, 4 which is a.ndims-1 + range(b.ndims)
+ # except where it==(b.ndims-2)
+ # b should skip 0, 1
+ mul = interp_ufuncs.get(space).multiply.func
+ add = interp_ufuncs.get(space).add.func
+ broadcast_shape = self.shape[:-1] + other.shape
+ #Aww, cmon, this is the product of a warped mind.
+ left_skip = [len(self.shape) - 1 + i for i in range(len(other.shape))
if i != other_critical_dim]
+ right_skip = range(len(self.shape) - 1)
+ arr = DotArray(mul, 'DotName', out_shape, dtype, self, other,
+ left_skip, right_skip)
+ arr.broadcast_shape = broadcast_shape
+ arr.result_skip = [len(out_shape) - 1]
+ #Make this lazy someday...
+ sig = signature.find_sig(signature.DotSignature(mul, 'dot', dtype,
+ self.create_sig(), other.create_sig()), arr)
+ assert isinstance(sig, signature.DotSignature)
+ self.do_dot_loop(sig, result, arr, add)
+ return result
+
def do_dot_loop(self, sig, result, arr, add):
    """Accumulate the products produced by ``sig`` into ``result``.

    ``arr`` is the DotArray container built by descr_dot. Its
    ``broadcast_shape`` is the shape every operand iterator walks, and its
    ``result_skip`` lists the broadcast dimensions along which the result
    offset does not advance. At each step the product yielded by
    ``sig.eval`` is combined, via ``add``, with the value currently stored
    in ``result`` at the matching offset, and written back.
    """
    frame = sig.create_frame(arr)
    shapelen = len(arr.broadcast_shape)
    # Skipped dimensions get stride 0, so every product along them lands
    # on the same result element.
    # NOTE(review): descr_dot sets result_skip to [len(out_shape) - 1],
    # which is the summed axis only when `other` has >= 2 dims -- verify
    # the 1-d `other` case.
    res_strides, res_backstrides = calculate_dot_strides(
        result.strides, result.backstrides,
        arr.broadcast_shape, arr.result_skip)
    # Assumes `result` is a fresh array starting at offset 0 (it is
    # allocated just before this call in descr_dot).
    ri = ViewIterator(0, res_strides, res_backstrides, arr.broadcast_shape)
    while not frame.done():
        v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
        value = add(sig.calc_dtype, v, result.getitem(ri.offset))
        result.setitem(ri.offset, value)
        frame.next(shapelen)
        ri = ri.next(shapelen)
def get_concrete(self):
    """Abstract; every subclass must provide its own implementation."""
    raise NotImplementedError
@@ -919,6 +934,23 @@
left, right)
self.dim = dim
class DotArray(Call2):
    """Container describing an (eagerly evaluated) dot product.

    It is built only inside descr_dot to carry the two operands and the
    per-operand ``left_skip`` / ``right_skip`` dimension lists to the
    signature machinery. It should never escape to application level.
    Revisit this class once dot is made lazy.
    """
    _immutable_fields_ = ['left', 'right']

    def __init__(self, ufunc, name, shape, dtype, left, right, left_skip,
                 right_skip):
        # dtype is passed twice to Call2 (presumably calc and result
        # dtype -- TODO confirm against Call2.__init__).
        Call2.__init__(self, ufunc, name, shape, dtype, dtype,
                       left, right)
        # Broadcast dimensions along which each operand does not advance;
        # DotSignature turns these into DotTransforms.
        self.left_skip = left_skip
        self.right_skip = right_skip

    def create_sig(self):
        # A DotArray never builds its own signature: descr_dot constructs
        # a DotSignature explicitly. Fail loudly instead of returning None.
        raise NotImplementedError
+
class ConcreteArray(BaseArray):
""" An array that have actual storage, whether owned or not
"""
diff --git a/pypy/module/micronumpy/interp_ufuncs.py
b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -192,17 +192,17 @@
sig=sig,
identity=identity,
shapelen=shapelen, arr=arr)
- iter = frame.get_final_iter()
+ iterator = frame.get_final_iter()
v = sig.eval(frame, arr).convert_to(sig.calc_dtype)
- if iter.first_line:
+ if iterator.first_line:
if identity is not None:
value = self.func(sig.calc_dtype, identity, v)
else:
value = v
else:
- cur = arr.left.getitem(iter.offset)
+ cur = arr.left.getitem(iterator.offset)
value = self.func(sig.calc_dtype, cur, v)
- arr.left.setitem(iter.offset, value)
+ arr.left.setitem(iterator.offset, value)
frame.next(shapelen)
def reduce_loop(self, shapelen, sig, frame, value, obj, dtype):
diff --git a/pypy/module/micronumpy/signature.py
b/pypy/module/micronumpy/signature.py
--- a/pypy/module/micronumpy/signature.py
+++ b/pypy/module/micronumpy/signature.py
@@ -2,7 +2,7 @@
from pypy.rlib.rarithmetic import intmask
from pypy.module.micronumpy.interp_iter import ViewIterator, ArrayIterator, \
ConstantIterator, AxisIterator, ViewTransform,\
- BroadcastTransform
+ BroadcastTransform, DotTransform
from pypy.rlib.jit import hint, unroll_safe, promote
""" Signature specifies both the numpy expression that has been constructed
@@ -331,7 +331,6 @@
assert isinstance(arr, Call2)
lhs = self.left.eval(frame, arr.left).convert_to(self.calc_dtype)
rhs = self.right.eval(frame, arr.right).convert_to(self.calc_dtype)
-
return self.binfunc(self.calc_dtype, lhs, rhs)
def debug_repr(self):
@@ -450,3 +449,21 @@
def debug_repr(self):
    """Return a compact textual form: reduction name plus operand repr."""
    inner = self.right.debug_repr()
    return 'AxisReduceSig(%s, %s)' % (self.name, inner)
+
class DotSignature(Call2):
    """Signature for the element-wise products of a dot operation.

    Both operands are iterated over the DotArray's shared
    ``broadcast_shape``. Each operand gets its own DotTransform, which
    zeroes the strides of the dimensions that operand must not advance
    along.
    """
    def _invent_numbering(self, cache, allnumbers):
        # Each side gets a fresh cache instead of the shared `cache`,
        # presumably so left and right never share an iterator even when
        # they are the same array (e.g. a.dot(a)), since they are walked
        # with different transforms.
        self.left._invent_numbering(new_cache(), allnumbers)
        self.right._invent_numbering(new_cache(), allnumbers)

    def _create_iter(self, iterlist, arraylist, arr, transforms):
        from pypy.module.micronumpy.interp_numarray import DotArray

        assert isinstance(arr, DotArray)
        ltransforms = transforms + [DotTransform(arr.broadcast_shape,
                                                 arr.left_skip)]
        rtransforms = transforms + [DotTransform(arr.broadcast_shape,
                                                 arr.right_skip)]
        self.left._create_iter(iterlist, arraylist, arr.left, ltransforms)
        self.right._create_iter(iterlist, arraylist, arr.right, rtransforms)

    def debug_repr(self):
        # Operands listed left then right, comma-separated.
        return 'DotSig(%s, %s, %s)' % (self.name, self.left.debug_repr(),
                                       self.right.debug_repr())
diff --git a/pypy/module/micronumpy/strides.py
b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -37,3 +37,17 @@
rstrides = [0] * (len(res_shape) - len(orig_shape)) + rstrides
rbackstrides = [0] * (len(res_shape) - len(orig_shape)) + rbackstrides
return rstrides, rbackstrides
+
def calculate_dot_strides(strides, backstrides, res_shape, skip_dims):
    """Spread an array's strides over the dimensions of ``res_shape``.

    Every dimension index listed in ``skip_dims`` gets a stride and
    backstride of 0. The remaining dimensions consume ``strides`` and
    ``backstrides`` in order. Only ``len(res_shape)`` is used.

    Returns a ``(strides, backstrides)`` pair of new lists.
    """
    rstrides = []
    rbackstrides = []
    src = 0
    for dim in range(len(res_shape)):
        if dim not in skip_dims:
            rstrides.append(strides[src])
            rbackstrides.append(backstrides[src])
            src += 1
        else:
            rstrides.append(0)
            rbackstrides.append(0)
    return rstrides, rbackstrides
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -867,7 +867,7 @@
assert c.any() == False
def test_dot(self):
- from _numpypy import array, dot
+ from _numpypy import array, dot, arange
a = array(range(5))
assert a.dot(a) == 30.0
@@ -876,13 +876,12 @@
assert dot(range(5), range(5)) == 30
assert (dot(5, [1, 2, 3]) == [5, 10, 15]).all()
- a = array([range(4), range(4, 8), range(8, 12)])
- b = array([range(3), range(3, 6), range(6, 9), range(9, 12)])
+ a = arange(12).reshape(3, 4)
+ b = arange(12).reshape(4, 3)
c = a.dot(b)
assert (c == [[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]).all()
- a = array([[range(4), range(4, 8), range(8, 12)],
- [range(12, 16), range(16, 20), range(20, 24)]])
+ a = arange(24).reshape(2, 3, 4)
raises(ValueError, "a.dot(a)")
b = a[0, :, :].T
#Superfluous shape test makes the intention of the test clearer
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit