Author: Alex Gaynor <alex.gay...@gmail.com> Branch: Changeset: r47087:618b0bba96a2 Date: 2011-09-05 10:02 -0700 http://bitbucket.org/pypy/pypy/changeset/618b0bba96a2/
Log: (snus, alex) Added the comparison functions to micronumpy. This is mostly the work from the numpy-comparisons branch, refactored by me. diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py --- a/pypy/module/micronumpy/__init__.py +++ b/pypy/module/micronumpy/__init__.py @@ -26,13 +26,19 @@ ("copysign", "copysign"), ("cos", "cos"), ("divide", "divide"), + ("equal", "equal"), ("exp", "exp"), ("fabs", "fabs"), ("floor", "floor"), + ("greater", "greater"), + ("greater_equal", "greater_equal"), + ("less", "less"), + ("less_equal", "less_equal"), ("maximum", "maximum"), ("minimum", "minimum"), ("multiply", "multiply"), ("negative", "negative"), + ("not_equal", "not_equal"), ("reciprocal", "reciprocal"), ("sign", "sign"), ("sin", "sin"), diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py --- a/pypy/module/micronumpy/interp_dtype.py +++ b/pypy/module/micronumpy/interp_dtype.py @@ -129,6 +129,16 @@ )) return impl +def raw_binop(func): + # Returns the result unwrapped. + @functools.wraps(func) + def impl(self, v1, v2): + return func(self, + self.for_computation(self.unbox(v1)), + self.for_computation(self.unbox(v2)) + ) + return impl + def unaryop(func): @functools.wraps(func) def impl(self, v): @@ -170,8 +180,24 @@ def bool(self, v): return bool(self.for_computation(self.unbox(v))) + @raw_binop + def eq(self, v1, v2): + return v1 == v2 + @raw_binop def ne(self, v1, v2): - return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2)) + return v1 != v2 + @raw_binop + def lt(self, v1, v2): + return v1 < v2 + @raw_binop + def le(self, v1, v2): + return v1 <= v2 + @raw_binop + def gt(self, v1, v2): + return v1 > v2 + @raw_binop + def ge(self, v1, v2): + return v1 >= v2 class FloatArithmeticDtype(ArithmaticTypeMixin): diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -74,6 +74,13 @@ descr_pow = _binop_impl("power") descr_mod = _binop_impl("mod") + descr_eq = _binop_impl("equal") + descr_ne = _binop_impl("not_equal") + descr_lt = _binop_impl("less") + descr_le = _binop_impl("less_equal") + descr_gt = _binop_impl("greater") + descr_ge = _binop_impl("greater_equal") + def _binop_right_impl(ufunc_name): def impl(self, space, w_other): w_other = scalar_w(space, @@ -404,10 +411,11 @@ """ Intermediate class for performing binary operations. """ - def __init__(self, signature, res_dtype, left, right): + def __init__(self, signature, calc_dtype, res_dtype, left, right): VirtualArray.__init__(self, signature, res_dtype) self.left = left self.right = right + self.calc_dtype = calc_dtype def _del_sources(self): self.left = None @@ -421,14 +429,14 @@ return self.right.find_size() def _eval(self, i): - lhs = self.left.eval(i).convert_to(self.res_dtype) - rhs = self.right.eval(i).convert_to(self.res_dtype) + lhs = self.left.eval(i).convert_to(self.calc_dtype) + rhs = self.right.eval(i).convert_to(self.calc_dtype) sig = jit.promote(self.signature) assert isinstance(sig, signature.Signature) call_sig = sig.components[0] assert isinstance(call_sig, signature.Call2) - return call_sig.func(self.res_dtype, lhs, rhs) + return call_sig.func(self.calc_dtype, lhs, rhs) class ViewArray(BaseArray): """ @@ -573,18 +581,28 @@ __pos__ = interp2app(BaseArray.descr_pos), __neg__ = interp2app(BaseArray.descr_neg), __abs__ = interp2app(BaseArray.descr_abs), + __add__ = interp2app(BaseArray.descr_add), __sub__ = interp2app(BaseArray.descr_sub), __mul__ = interp2app(BaseArray.descr_mul), __div__ = interp2app(BaseArray.descr_div), __pow__ = interp2app(BaseArray.descr_pow), __mod__ = interp2app(BaseArray.descr_mod), + __radd__ = interp2app(BaseArray.descr_radd), __rsub__ = interp2app(BaseArray.descr_rsub), __rmul__ = interp2app(BaseArray.descr_rmul), __rdiv__ = interp2app(BaseArray.descr_rdiv), __rpow__ = interp2app(BaseArray.descr_rpow), __rmod__ = interp2app(BaseArray.descr_rmod), + + __eq__ = interp2app(BaseArray.descr_eq), + __ne__ = interp2app(BaseArray.descr_ne), + __lt__ = interp2app(BaseArray.descr_lt), + __le__ = interp2app(BaseArray.descr_le), + __gt__ = interp2app(BaseArray.descr_gt), + __ge__ = interp2app(BaseArray.descr_ge), + __repr__ = interp2app(BaseArray.descr_repr), __str__ = interp2app(BaseArray.descr_str), diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -113,10 +113,11 @@ argcount = 2 def __init__(self, func, name, promote_to_float=False, promote_bools=False, - identity=None): + identity=None, comparison_func=False): W_Ufunc.__init__(self, name, promote_to_float, promote_bools, identity) self.func = func + self.comparison_func = comparison_func self.signature = signature.Call2(func) self.reduce_signature = signature.BaseSignature() @@ -127,18 +128,25 @@ [w_lhs, w_rhs] = args_w w_lhs = convert_to_array(space, w_lhs) w_rhs = convert_to_array(space, w_rhs) - res_dtype = find_binop_result_dtype(space, + calc_dtype = find_binop_result_dtype(space, w_lhs.find_dtype(), w_rhs.find_dtype(), promote_to_float=self.promote_to_float, promote_bools=self.promote_bools, ) + if self.comparison_func: + res_dtype = space.fromcache(interp_dtype.W_BoolDtype) + else: + res_dtype = calc_dtype if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar): - return self.func(res_dtype, w_lhs.value, w_rhs.value).wrap(space) + return self.func(calc_dtype, + w_lhs.value.convert_to(calc_dtype), + w_rhs.value.convert_to(calc_dtype) + ).wrap(space) new_sig = signature.Signature.find_sig([ self.signature, w_lhs.signature, w_rhs.signature ]) - w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs) + w_res = Call2(new_sig, calc_dtype, res_dtype, w_lhs, w_rhs) w_lhs.add_invalidates(w_res) w_rhs.add_invalidates(w_res) return w_res @@ -209,13 +217,16 @@ return space.fromcache(interp_dtype.W_Float64Dtype) -def ufunc_dtype_caller(ufunc_name, op_name, argcount): +def ufunc_dtype_caller(space, ufunc_name, op_name, argcount, comparison_func): if argcount == 1: def impl(res_dtype, value): return getattr(res_dtype, op_name)(value) elif argcount == 2: def impl(res_dtype, lvalue, rvalue): - return getattr(res_dtype, op_name)(lvalue, rvalue) + res = getattr(res_dtype, op_name)(lvalue, rvalue) + if comparison_func: + res = space.fromcache(interp_dtype.W_BoolDtype).box(res) + return res return func_with_new_name(impl, ufunc_name) class UfuncState(object): @@ -229,6 +240,13 @@ ("mod", "mod", 2, {"promote_bools": True}), ("power", "pow", 2, {"promote_bools": True}), + ("equal", "eq", 2, {"comparison_func": True}), + ("not_equal", "ne", 2, {"comparison_func": True}), + ("less", "lt", 2, {"comparison_func": True}), + ("less_equal", "le", 2, {"comparison_func": True}), + ("greater", "gt", 2, {"comparison_func": True}), + ("greater_equal", "ge", 2, {"comparison_func": True}), + ("maximum", "max", 2), ("minimum", "min", 2), @@ -262,7 +280,9 @@ identity = space.fromcache(interp_dtype.W_Int64Dtype).adapt_val(identity) extra_kwargs["identity"] = identity - func = ufunc_dtype_caller(ufunc_name, op_name, argcount) + func = ufunc_dtype_caller(space, ufunc_name, op_name, argcount, + comparison_func=extra_kwargs.get("comparison_func", False) + ) if argcount == 1: ufunc = W_Ufunc1(func, ufunc_name, **extra_kwargs) elif argcount == 2: diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -557,6 +557,26 @@ assert array([1.2, 5]).dtype is dtype(float) assert array([]).dtype is dtype(float) + def test_comparison(self): + import operator + from numpy import array, dtype + + a = array(range(5)) + b = array(range(5), float) + for func in [ + operator.eq, operator.ne, operator.lt, operator.le, operator.gt, + operator.ge + ]: + c = func(a, 3) + assert c.dtype is dtype(bool) + for i in xrange(5): + assert c[i] == func(a[i], 3) + + c = func(b, 3) + assert c.dtype is dtype(bool) + for i in xrange(5): + assert c[i] == func(b[i], 3) + class AppTestSupport(object): def setup_class(cls): diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -310,4 +310,30 @@ assert add.reduce([1, 2, 3]) == 6 assert maximum.reduce([1]) == 1 assert maximum.reduce([1, 2, 3]) == 3 - raises(ValueError, maximum.reduce, []) \ No newline at end of file + raises(ValueError, maximum.reduce, []) + + def test_comparisons(self): + import operator + from numpy import equal, not_equal, less, less_equal, greater, greater_equal + + for ufunc, func in [ + (equal, operator.eq), + (not_equal, operator.ne), + (less, operator.lt), + (less_equal, operator.le), + (greater, operator.gt), + (greater_equal, operator.ge), + ]: + for a, b in [ + (3, 3), + (3, 4), + (4, 3), + (3.0, 3.0), + (3.0, 3.5), + (3.5, 3.0), + (3.0, 3), + (3, 3.0), + (3.5, 3), + (3, 3.5), + ]: + assert ufunc(a, b) is func(a, b) \ No newline at end of file _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit