Author: mattip Branch: matrixmath-dot Changeset: r50046:84755e29506f Date: 2011-12-01 22:59 +0200 http://bitbucket.org/pypy/pypy/changeset/84755e29506f/
Log: add two arg functionality to test_compile diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py --- a/pypy/module/micronumpy/compile.py +++ b/pypy/module/micronumpy/compile.py @@ -30,6 +30,7 @@ pass SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any", "unegative"] +TWO_ARG_FUNCTIONS = ["dot"] class FakeSpace(object): w_ValueError = None @@ -381,17 +382,28 @@ w_res = neg.call(interp.space, [arr]) else: assert False # unreachable code - if isinstance(w_res, BaseArray): - return w_res - if isinstance(w_res, FloatObject): - dtype = interp.space.fromcache(W_Float64Dtype) - elif isinstance(w_res, BoolObject): - dtype = interp.space.fromcache(W_BoolDtype) - else: - dtype = None - return scalar_w(interp.space, dtype, w_res) + elif self.name in TWO_ARG_FUNCTIONS: + if len(self.args) != 2: + raise ArgumentMismatch + arr0 = self.args[0].execute(interp) + arr1 = self.args[1].execute(interp) + if not isinstance(arr0, BaseArray): + raise ArgumentNotAnArray + if not isinstance(arr1, BaseArray): + raise ArgumentNotAnArray + elif self.name == "dot": + w_res = arr0.descr_dot(interp.space, arr1) else: raise WrongFunctionName + if isinstance(w_res, BaseArray): + return w_res + if isinstance(w_res, FloatObject): + dtype = interp.space.fromcache(W_Float64Dtype) + elif isinstance(w_res, BoolObject): + dtype = interp.space.fromcache(W_BoolDtype) + else: + dtype = None + return scalar_w(interp.space, dtype, w_res) _REGEXES = [ ('-?[\d\.]+', 'number'), @@ -525,6 +537,9 @@ args = [] tokens.pop() # lparen while tokens.get(0).name != 'paren_right': + if tokens.get(0).name == 'coma': + tokens.pop() + continue args.append(self.parse_expression(tokens)) return FunctionCall(name, args) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -519,9 +519,11 @@ return w_res.descr_sum(space) #Do the dims match? my_critical_dim_size = self.shape[-1] - other_critical_dim_size = w_other.shape[0] + other_critical_dim_size = w_other.shape[0] + other_critical_dim_stride = w_other.strides[0] if len(w_other.shape) > 2: other_critical_dim_size = w_other.shape[-2] + other_critical_dim_stride = w_other.strides[-2] if my_critical_dim_size != other_critical_dim_size: raise OperationError(space.w_ValueError, space.wrap( "objects are not aligned")) @@ -529,23 +531,26 @@ out_size = 1 for os in out_shape: out_size *= os + out_ndims = len(out_shape) dtype = interp_ufuncs.find_binop_result_dtype(space, self.find_dtype(), w_other.find_dtype()) #TODO: what should the order be? C or F? arr = NDimArray(out_size, out_shape, dtype=dtype) + out_iter = ArrayIterator(out_size) + #TODO: invalidate self, w_other with arr + me_iter = BroadcastIterator(self,self.shape[:-1] + [1]) + other_iter = BroadcastIterator(self, + w_other.shape[:-2] + [1] + w_other.shape[-1:]) + while not out_iter.done(): + i = OneDimIterator(me_iter.get_offset(), self.strides[-1], self.shape[-1]) + j = OneDimIterator(other_iter.get_offset(), other_critical_dim_stride, other_critical_dim_size) + #Heres what I would like to do, but how? + #value = sum(mult_with_iters(self, i, w_other, j)) + #arr.setitem(out_iter, value) + out_iter = out_iter.next(out_ndims) + me_iter = me_iter.next(0) + other_iter = other_iter.next(0) return arr - out_iter = ArrayIterator(out_size) - me_iter = BroadcastIterator(self,self.shape[:len(self.size)-1] + [1]) - other_iter = BroadcastIter(self, - w_other.shape[:-2] + [1] + w_other.shape[-1]) - call2 = instantiate(Call2) - call2.left = self - call2.right = w_other - call2.calc_dtype = None - call2.size = my_critical_dim_size - - while not out_iter.done(): - pass def get_concrete(self): diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -161,7 +161,6 @@ self.signature, w_lhs.signature, w_rhs.signature ]) new_shape = shape_agreement(space, w_lhs.shape, w_rhs.shape) - edf = jkl w_res = Call2(new_sig, new_shape, calc_dtype, res_dtype, w_lhs, w_rhs) w_lhs.add_invalidates(w_res) diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py --- a/pypy/module/micronumpy/test/test_compile.py +++ b/pypy/module/micronumpy/test/test_compile.py @@ -232,3 +232,11 @@ a -> 3 """) assert interp.results[0].value.val == 11 + def test_dot(self): + interp = self.run(""" + a = [[1, 2], [3, 4]] + b = [[5, 6], [7, 8]] + c = dot(a, b) + c -> 0 -> 0 + """) + assert interp.results[0].value.val == 19 _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit