Author: mattip <matti.pi...@gmail.com> Branch: ufuncapi Changeset: r72139:5dd1671a3061 Date: 2014-06-22 21:44 +0300 http://bitbucket.org/pypy/pypy/changeset/5dd1671a3061/
Log: wip diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py --- a/pypy/module/micronumpy/loop.py +++ b/pypy/module/micronumpy/loop.py @@ -10,6 +10,7 @@ from pypy.module.micronumpy.base import W_NDimArray from pypy.module.micronumpy.iterators import PureShapeIter, AxisIter, \ AllButAxisIter +from pypy.interpreter.argument import Arguments call2_driver = jit.JitDriver( @@ -93,12 +94,13 @@ def call_many_to_one(space, shape, func, res_dtype, w_in, out): # out must hav been built. func needs no calc_type, is usually an # external ufunc - iters_and_states = [i.create_iter(shape) for i in w_in] + iters_and_states = [list(i.create_iter(shape)) for i in w_in] shapelen = len(shape) + out_iter, out_state = out.create_iter(shape) while not out_iter.done(out_state): call_many_to_one_driver.jit_merge_point(shapelen=shapelen, func=func, res_dtype=res_dtype) - vals = [None] + [i_s[0].getitem(i_s[1]) for i_s in iters_and_states] + vals = [i_s[0].getitem(i_s[1]) for i_s in iters_and_states] arglist = space.wrap(vals) out_val = space.call_args(func, Arguments.frompacked(space, arglist)) out_iter.setitem(out_state, out_val.convert_to(space, res_dtype)) @@ -107,6 +109,39 @@ out_state = out_iter.next(out_state) return out +call_many_to_many_driver = jit.JitDriver( + name='numpy_call_many_to_many', + greens=['shapelen', 'func', 'res_dtype'], + reds='auto') + +def call_many_to_many(space, shape, func, res_dtype, w_in, w_out): + # out must hav been built. func needs no calc_type, is usually an + # external ufunc + in_iters_and_states = [list(i.create_iter(shape)) for i in w_in] + shapelen = len(shape) + out_iters_and_states = [list(i.create_iter(shape)) for i in w_out] + # what does the function return? + while not out_iters_and_states[0][0].done(out_iters_and_states[0][1]): + call_many_to_many_driver.jit_merge_point(shapelen=shapelen, func=func, + res_dtype=res_dtype) + vals = [i_s[0].getitem(i_s[1]) for i_s in in_iters_and_states] + arglist = space.wrap(vals) + out_vals = space.call_args(func, Arguments.frompacked(space, arglist)) + # XXX bad form + if not isinstance(out_vals,(list, tuple)): + out_iter, out_state = out_iters_and_states[0] + out_iter.setitem(out_state, out_vals.convert_to(space, res_dtype)) + out_iters_and_states[0][1] = out_iters_and_states[0][0].next(out_iters_and_states[0][1]) + else: + for i in range(len(out_iters_and_states)): + out_iter, out_state = out_iters_and_states[i] + out_iter.setitem(out_state, out_vals[i].convert_to(space, res_dtype)) + out_iters_and_states[i][1] = out_iters_and_states[i][0].next(out_iters_and_states[i][1]) + for i in range(len(iters_and_states)): + in_iters_and_states[i][1] = in_iters_and_states[i][0].next(in_iters_and_states[i][1]) + return out + + def setslice(space, shape, target, source): # note that unlike everything else, target and source here are diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py --- a/pypy/module/micronumpy/ufuncs.py +++ b/pypy/module/micronumpy/ufuncs.py @@ -493,7 +493,7 @@ self.nin = nin self.nout = nout self.nargs = nin + max(nout, 1) # ufuncs can always be called with an out=<> kwarg - if dtypes != 'match' and (len(dtypes) % len(funcs) != 0 or + if dtypes[0] != 'match' and (len(dtypes) % len(funcs) != 0 or len(dtypes) / len(funcs) != self.nargs): raise oefmt(space.w_ValueError, "generic ufunc with %d functions, %d arguments, but %d dtypes", @@ -527,11 +527,13 @@ index = self.type_resolver(space, inargs, outargs) self.alloc_outargs(space, index, inargs, outargs) # XXX handle inner-loop indexing + new_shape = inargs[0].get_shape() + res_dtype = outargs[0].get_dtype() if len(outargs) < 2: - return loop.call_many_to_one(space, new_shape, self.func, + return loop.call_many_to_one(space, new_shape, self.funcs[index], res_dtype, inargs, outargs[0]) - return loop.call_many_to_many(space, new_shape, self.func, - res_dtype, inargs, out) + return loop.call_many_to_many(space, new_shape, self.funcs[index], + res_dtype, inargs, outargs) def type_resolver(self, space, index, outargs): # Find a match for the inargs.dtype in self.dtypes, like @@ -954,16 +956,14 @@ if space.is_none(w_dtypes) and not signature: raise oefmt(space.w_NotImplementedError, 'object dtype requested but not implemented') - if space.isinstance_w(w_dtypes, space.w_str): - if not space.str_w(w_dtypes) == 'match': - raise oefmt(space.w_ValueError, - 'unknown out_dtype value "%s"', space.str_w(w_dtypes)) - dtypes = 'match' elif (space.isinstance_w(w_dtypes, space.w_tuple) or space.isinstance_w(w_dtypes, space.w_list)): dtypes = space.listview(w_dtypes) - for i in range(len(dtypes)): - dtypes[i] = descriptor.decode_w_dtype(space, dtypes[i]) + if space.str_w(dtypes[0]) == 'match': + dtypes = ['match',] + else: + for i in range(len(dtypes)): + dtypes[i] = descriptor.decode_w_dtype(space, dtypes[i]) else: raise oefmt(space.w_ValueError, 'dtypes must be None or a list of dtypes') @@ -976,9 +976,9 @@ raise oefmt(space.w_ValueError, 'identity must be 0, 1, or None') if nin==1 and nout==1 and dtypes == 'match': - w_ret = W_Ufunc1(wrap_ext_func(func[0], name) + w_ret = W_Ufunc1(wrap_ext_func(func[0], name)) elif nin==2 and nout==1 and dtypes == 'match': - w_ret = W_Ufunc2(wrap_ext_func(func[0]), name) + w_ret = W_Ufunc2(wrap_ext_func(func[0], name)) else: w_ret = W_UfuncGeneric(space, func, name, identity, nin, nout, dtypes, signature) if doc: _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit