Author: Alex Gaynor <alex.gay...@gmail.com> Branch: numpy-dtype-alt Changeset: r46729:3a0ecdbc1565 Date: 2011-08-23 07:31 -0500 http://bitbucket.org/pypy/pypy/changeset/3a0ecdbc1565/
Log: Added dtype guessing, this also fixes the return type on things like numpy.maximum(1, 2), which used to be buggy and return a float. diff --git a/TODO.txt b/TODO.txt --- a/TODO.txt +++ b/TODO.txt @@ -1,5 +1,4 @@ TODO for mering numpy-dtype-alt =============================== -* More operations on more dtypes -* dtype guessing +* More operations on more dtypes (including copy-paste reduction) \ No newline at end of file diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py --- a/pypy/module/micronumpy/interp_dtype.py +++ b/pypy/module/micronumpy/interp_dtype.py @@ -218,9 +218,29 @@ class IntegerArithmeticDtype(object): _mixin_ = True + # XXX: reduce the copy paste @binop def add(self, v1, v2): return widen(v1) + widen(v2) + @binop + def sub(self, v1, v2): + return widen(v1) - widen(v2) + @binop + def mul(self, v1, v2): + return widen(v1) * widen(v2) + @binop + def div(self, v1, v2): + return widen(v1) / widen(v2) + @binop + def mod(self, v1, v2): + return widen(v1) % widen(v2) + + @binop + def max(self, v1, v2): + return max(widen(v1), widen(v2)) + + def bool(self, v): + return bool(widen(self.unbox(v))) def str_format(self, item): return str(widen(self.unbox(item))) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -37,10 +37,19 @@ self.invalidates.append(other) def descr__new__(space, w_subtype, w_size_or_iterable, w_dtype=None): + l = space.listview(w_size_or_iterable) + if space.is_w(w_dtype, space.w_None): + w_dtype = None + for w_item in l: + w_dtype = interp_ufuncs.find_dtype_for_scalar(space, w_item, w_dtype) + if w_dtype is space.fromcache(interp_dtype.W_Float64Dtype): + break + if w_dtype is None: + w_dtype = space.w_None + dtype = space.interp_w(interp_dtype.W_Dtype, space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype) ) - l = space.listview(w_size_or_iterable) arr = SingleDimArray(len(l), dtype=dtype) i = 0 for w_elem in l: @@ -71,7 +80,10 @@ def _binop_right_impl(w_ufunc): def impl(self, space, w_other): - w_other = scalar_w(space, interp_dtype.W_Float64Dtype, w_other) + w_other = scalar_w(space, + interp_ufuncs.find_dtype_for_scalar(space, w_other, self.find_dtype()), + w_other + ) return w_ufunc(space, w_other, self) return func_with_new_name(impl, "binop_right_%s_impl" % w_ufunc.__name__) @@ -295,11 +307,11 @@ slice_driver.jit_merge_point(signature=source.signature, step=step, stop=stop, i=i, j=j, source=source, dest=dest) - dest.setitem(i, source.eval(j)) + dest.setitem(i, source.eval(j).convert_to(dest.find_dtype())) j += 1 i += step -def convert_to_array (space, w_obj): +def convert_to_array(space, w_obj): if isinstance(w_obj, BaseArray): return w_obj elif space.issequence_w(w_obj): @@ -309,16 +321,11 @@ return w_obj else: # If it's a scalar - return scalar_w(space, interp_dtype.W_Float64Dtype, w_obj) + dtype = interp_ufuncs.find_dtype_for_scalar(space, w_obj) + return scalar_w(space, dtype, w_obj) -@specialize.arg(1) def scalar_w(space, dtype, w_obj): - return Scalar(space.fromcache(dtype), scalar(space, dtype, w_obj)) - -@specialize.arg(1) -def scalar(space, dtype, w_obj): - dtype = space.fromcache(dtype) - return dtype.unwrap(space, w_obj) + return Scalar(dtype, dtype.unwrap(space, w_obj)) class Scalar(BaseArray): """ diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -13,15 +13,14 @@ convert_to_array, Scalar) w_obj = convert_to_array(space, w_obj) - if isinstance(w_obj, Scalar): - res_dtype = space.fromcache(interp_dtype.W_Float64Dtype) - return func(res_dtype, w_obj.value).wrap(space) - - new_sig = signature.Signature.find_sig([call_sig, w_obj.signature]) res_dtype = find_unaryop_result_dtype(space, w_obj.find_dtype(), promote_to_float=promote_to_float, ) + if isinstance(w_obj, Scalar): + return func(res_dtype, w_obj.value.convert_to(res_dtype)).wrap(space) + + new_sig = signature.Signature.find_sig([call_sig, w_obj.signature]) w_res = Call1(new_sig, res_dtype, w_obj) w_obj.add_invalidates(w_res) return w_res @@ -38,15 +37,16 @@ w_lhs = convert_to_array(space, w_lhs) w_rhs = convert_to_array(space, w_rhs) - if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar): - res_dtype = space.fromcache(interp_dtype.W_Float64Dtype) - return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space) - - new_sig = signature.Signature.find_sig([call_sig, w_lhs.signature, w_rhs.signature]) res_dtype = find_binop_result_dtype(space, w_lhs.find_dtype(), w_rhs.find_dtype(), promote_to_float=promote_to_float, ) + if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar): + return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space) + + new_sig = signature.Signature.find_sig([ + call_sig, w_lhs.signature, w_rhs.signature + ]) w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs) w_lhs.add_invalidates(w_res) w_rhs.add_invalidates(w_res) @@ -87,6 +87,21 @@ assert False return dt +def find_dtype_for_scalar(space, w_obj, current_guess=None): + w_type = space.type(w_obj) + + bool_dtype = space.fromcache(interp_dtype.W_BoolDtype) + int64_dtype = space.fromcache(interp_dtype.W_Int64Dtype) + + if space.is_w(w_type, space.w_bool): + if current_guess is None: + return bool_dtype + elif space.is_w(w_type, space.w_int): + if (current_guess is None or current_guess is bool_dtype or + current_guess is int64_dtype): + return int64_dtype + return space.fromcache(interp_dtype.W_Float64Dtype) + def ufunc_dtype_caller(ufunc_name, op_name, argcount, **kwargs): if argcount == 1: diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -52,7 +52,7 @@ def test_repr(self): from numpy import array, zeros - a = array(range(5)) + a = array(range(5), float) assert repr(a) == "array([0.0, 1.0, 2.0, 3.0, 4.0])" a = zeros(1001) assert repr(a) == "array([0.0, 0.0, 0.0, ..., 0.0, 0.0, 0.0])" @@ -63,7 +63,7 @@ def test_repr_slice(self): from numpy import array, zeros - a = array(range(5)) + a = array(range(5), float) b = a[1::2] assert repr(b) == "array([1.0, 3.0])" a = zeros(2002) @@ -72,7 +72,7 @@ def test_str(self): from numpy import array, zeros - a = array(range(5)) + a = array(range(5), float) assert str(a) == "[0.0 1.0 2.0 3.0 4.0]" assert str((2*a)[:]) == "[0.0 2.0 4.0 6.0 8.0]" a = zeros(1001) @@ -88,7 +88,7 @@ def test_str_slice(self): from numpy import array, zeros - a = array(range(5)) + a = array(range(5), float) b = a[1::2] assert str(b) == "[1.0 3.0]" a = zeros(2002) @@ -144,7 +144,7 @@ def test_setslice_list(self): from numpy import array - a = array(range(5)) + a = array(range(5), float) b = [0., 1.] a[1:4:2] = b assert a[1] == 0. @@ -152,7 +152,7 @@ def test_setslice_constant(self): from numpy import array - a = array(range(5)) + a = array(range(5), float) a[1:4:2] = 0. assert a[1] == 0. assert a[3] == 0. @@ -261,7 +261,7 @@ def test_div_other(self): from numpy import array a = array(range(5)) - b = array([2, 2, 2, 2, 2]) + b = array([2, 2, 2, 2, 2], float) c = a / b for i in range(5): assert c[i] == i / 2.0 @@ -275,7 +275,7 @@ def test_pow(self): from numpy import array - a = array(range(5)) + a = array(range(5), float) b = a ** a for i in range(5): print b[i], i**i @@ -283,7 +283,7 @@ def test_pow_other(self): from numpy import array - a = array(range(5)) + a = array(range(5), float) b = array([2, 2, 2, 2, 2]) c = a ** b for i in range(5): @@ -291,7 +291,7 @@ def test_pow_constant(self): from numpy import array - a = array(range(5)) + a = array(range(5), float) b = a ** 2 for i in range(5): assert b[i] == i ** 2 @@ -484,6 +484,16 @@ for i in xrange(5): assert b[i] == 2.5 * a[i] + def test_dtype_guessing(self): + from numpy import array, dtype + + assert array([True]).dtype is dtype(bool) + assert array([True, 1]).dtype is dtype(long) + assert array([1, 2, 3]).dtype is dtype(long) + assert array([1.2, True]).dtype is dtype(float) + assert array([1.2, 5]).dtype is dtype(float) + assert array([]).dtype is dtype(float) + class AppTestSupport(object): def setup_class(cls): @@ -496,5 +506,4 @@ a = fromstring(self.data) for i in range(4): assert a[i] == i + 1 - raises(ValueError, fromstring, "abc") - + raises(ValueError, fromstring, "abc") \ No newline at end of file diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -110,6 +110,10 @@ for i in range(3): assert c[i] == max(a[i], b[i]) + x = maximum(2, 3) + assert x == 3 + assert type(x) is int + def test_multiply(self): from numpy import array, multiply diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py --- a/pypy/module/micronumpy/test/test_zjit.py +++ b/pypy/module/micronumpy/test/test_zjit.py @@ -34,7 +34,7 @@ ar = SingleDimArray(i, dtype=self.float64_dtype) v = interp_ufuncs.add(self.space, ar, - scalar_w(self.space, W_Float64Dtype, self.space.wrap(4.5)) + scalar_w(self.space, self.float64_dtype, self.space.wrap(4.5)) ) assert isinstance(v, BaseArray) return v.get_concrete().eval(3).val @@ -181,9 +181,9 @@ def f(i): ar = SingleDimArray(i, dtype=self.float64_dtype) - v1 = interp_ufuncs.add(space, ar, scalar_w(space, W_Float64Dtype, space.wrap(4.5))) + v1 = interp_ufuncs.add(space, ar, scalar_w(space, self.float64_dtype, space.wrap(4.5))) assert isinstance(v1, BaseArray) - v2 = interp_ufuncs.multiply(space, v1, scalar_w(space, W_Float64Dtype, space.wrap(4.5))) + v2 = interp_ufuncs.multiply(space, v1, scalar_w(space, self.float64_dtype, space.wrap(4.5))) v1.force_if_needed() assert isinstance(v2, BaseArray) return v2.get_concrete().eval(3).val _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit