Author: Ilya Osadchiy <osadchiy.i...@gmail.com> Branch: numpy-comparison Changeset: r47014:edb6c31894de Date: 2011-09-02 10:37 +0300 http://bitbucket.org/pypy/pypy/changeset/edb6c31894de/
Log: Initial implementation (tests pass, translation fails) diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py --- a/pypy/module/micronumpy/__init__.py +++ b/pypy/module/micronumpy/__init__.py @@ -25,17 +25,18 @@ 'floor': 'interp_ufuncs.floor', 'maximum': 'interp_ufuncs.maximum', 'minimum': 'interp_ufuncs.minimum', - 'multiply': 'interp_ufuncs.multiply', - 'negative': 'interp_ufuncs.negative', - 'reciprocal': 'interp_ufuncs.reciprocal', - 'sign': 'interp_ufuncs.sign', - 'subtract': 'interp_ufuncs.subtract', - 'sin': 'interp_ufuncs.sin', - 'cos': 'interp_ufuncs.cos', - 'tan': 'interp_ufuncs.tan', - 'arcsin': 'interp_ufuncs.arcsin', - 'arccos': 'interp_ufuncs.arccos', - 'arctan': 'interp_ufuncs.arctan', + 'multiply': 'interp_ufuncs.multiply', + 'negative': 'interp_ufuncs.negative', + 'reciprocal': 'interp_ufuncs.reciprocal', + 'sign': 'interp_ufuncs.sign', + 'subtract': 'interp_ufuncs.subtract', + 'sin': 'interp_ufuncs.sin', + 'cos': 'interp_ufuncs.cos', + 'tan': 'interp_ufuncs.tan', + 'arcsin': 'interp_ufuncs.arcsin', + 'arccos': 'interp_ufuncs.arccos', + 'arctan': 'interp_ufuncs.arctan', + 'equal': 'interp_ufuncs.equal', } appleveldefs = { diff --git a/pypy/module/micronumpy/interp_dtype.py b/pypy/module/micronumpy/interp_dtype.py --- a/pypy/module/micronumpy/interp_dtype.py +++ b/pypy/module/micronumpy/interp_dtype.py @@ -125,6 +125,15 @@ )) return impl +def bool_binop(func): + @functools.wraps(func) + def impl(self, v1, v2): + return self.box(func(self, + self.for_computation(self.unbox(v1)), + self.for_computation(self.unbox(v2)), + )) + return impl + def unaryop(func): @functools.wraps(func) def impl(self, v): @@ -147,6 +156,25 @@ def div(self, v1, v2): return v1 / v2 + @bool_binop + def eq(self, v1, v2): + return v1 == v2 + @bool_binop + def ne(self, v1, v2): + return v1 != v2 + @bool_binop + def lt(self, v1, v2): + return v1 < v2 + @bool_binop + def le(self, v1, v2): + return v1 <= v2 + @bool_binop + def gt(self, v1, v2): + return v1 > v2 + @bool_binop + def ge(self, v1, v2): + return v1 >= v2 + @unaryop def pos(self, v): return +v @@ -166,8 +194,8 @@ def bool(self, v): return bool(self.for_computation(self.unbox(v))) - def ne(self, v1, v2): - return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2)) +# def ne(self, v1, v2): +# return self.for_computation(self.unbox(v1)) != self.for_computation(self.unbox(v2)) class FloatArithmeticDtype(ArithmaticTypeMixin): @@ -355,4 +383,4 @@ num = interp_attrproperty("num", cls=W_Dtype), kind = interp_attrproperty("kind", cls=W_Dtype), shape = GetSetProperty(W_Dtype.descr_get_shape), -) \ No newline at end of file +) diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -74,6 +74,13 @@ descr_pow = _binop_impl(interp_ufuncs.power) descr_mod = _binop_impl(interp_ufuncs.mod) + descr_eq = _binop_impl(interp_ufuncs.equal) + descr_ne = _binop_impl(interp_ufuncs.not_equal) + descr_lt = _binop_impl(interp_ufuncs.less) + descr_le = _binop_impl(interp_ufuncs.less_equal) + descr_gt = _binop_impl(interp_ufuncs.greater) + descr_ge = _binop_impl(interp_ufuncs.greater_equal) + def _binop_right_impl(w_ufunc): def impl(self, space, w_other): w_other = scalar_w(space, @@ -152,7 +159,7 @@ size=size, i=i, result=result, cur_best=cur_best) new_best = getattr(dtype, op_name)(cur_best, self.eval(i)) - if dtype.ne(new_best, cur_best): + if dtype.unbox(dtype.ne(new_best, cur_best)): result = i cur_best = new_best i += 1 @@ -350,11 +357,12 @@ """ Class for representing virtual arrays, such as binary ops or ufuncs """ - def __init__(self, signature, res_dtype): + def __init__(self, signature, res_dtype, calc_dtype): BaseArray.__init__(self) self.forced_result = None self.signature = signature self.res_dtype = res_dtype + self.calc_dtype = calc_dtype def _del_sources(self): # Function for deleting references to source arrays, to allow garbage-collecting them @@ -402,7 +410,7 @@ class Call1(VirtualArray): def __init__(self, signature, res_dtype, values): - VirtualArray.__init__(self, signature, res_dtype) + VirtualArray.__init__(self, signature, res_dtype, res_dtype) self.values = values def _del_sources(self): @@ -427,8 +435,8 @@ """ Intermediate class for performing binary operations. """ - def __init__(self, signature, res_dtype, left, right): - VirtualArray.__init__(self, signature, res_dtype) + def __init__(self, signature, res_dtype, calc_dtype, left, right): + VirtualArray.__init__(self, signature, res_dtype, calc_dtype) self.left = left self.right = right @@ -444,14 +452,14 @@ return self.right.find_size() def _eval(self, i): - lhs = self.left.eval(i).convert_to(self.res_dtype) - rhs = self.right.eval(i).convert_to(self.res_dtype) + lhs = self.left.eval(i).convert_to(self.calc_dtype) + rhs = self.right.eval(i).convert_to(self.calc_dtype) sig = jit.promote(self.signature) assert isinstance(sig, signature.Signature) call_sig = sig.components[0] assert isinstance(call_sig, signature.Call2) - return call_sig.func(self.res_dtype, lhs, rhs) + return call_sig.func(self.calc_dtype, lhs, rhs).convert_to(self.res_dtype) class ViewArray(BaseArray): """ @@ -610,6 +618,13 @@ __repr__ = interp2app(BaseArray.descr_repr), __str__ = interp2app(BaseArray.descr_str), + __eq__ = interp2app(BaseArray.descr_eq), + __ne__ = interp2app(BaseArray.descr_ne), + __lt__ = interp2app(BaseArray.descr_lt), + __le__ = interp2app(BaseArray.descr_le), + __gt__ = interp2app(BaseArray.descr_gt), + __ge__ = interp2app(BaseArray.descr_ge), + dtype = GetSetProperty(BaseArray.descr_get_dtype), shape = GetSetProperty(BaseArray.descr_get_shape), diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py --- a/pypy/module/micronumpy/interp_ufuncs.py +++ b/pypy/module/micronumpy/interp_ufuncs.py @@ -24,9 +24,9 @@ return w_res return func_with_new_name(impl, "%s_dispatcher" % func.__name__) -def ufunc2(func=None, promote_to_float=False): +def ufunc2(func=None, promote_to_float=False, bool_result=False): if func is None: - return lambda func: ufunc2(func, promote_to_float) + return lambda func: ufunc2(func, promote_to_float, bool_result) call_sig = signature.Call2(func) def impl(space, w_lhs, w_rhs): @@ -35,17 +35,25 @@ w_lhs = convert_to_array(space, w_lhs) w_rhs = convert_to_array(space, w_rhs) - res_dtype = find_binop_result_dtype(space, + calc_dtype = find_binop_result_dtype(space, w_lhs.find_dtype(), w_rhs.find_dtype(), promote_to_float=promote_to_float, ) + # Some operations return bool regardless of input type + if bool_result: + res_dtype = space.fromcache(interp_dtype.W_BoolDtype) + else: + res_dtype = calc_dtype if isinstance(w_lhs, Scalar) and isinstance(w_rhs, Scalar): - return func(res_dtype, w_lhs.value, w_rhs.value).wrap(space) + lhs = w_lhs.value.convert_to(calc_dtype) + rhs = w_rhs.value.convert_to(calc_dtype) + interm_res = func(calc_dtype, lhs, rhs) + return interm_res.convert_to(res_dtype).wrap(space) new_sig = signature.Signature.find_sig([ call_sig, w_lhs.signature, w_rhs.signature ]) - w_res = Call2(new_sig, res_dtype, w_lhs, w_rhs) + w_res = Call2(new_sig, res_dtype, calc_dtype, w_lhs, w_rhs) w_lhs.add_invalidates(w_res) w_rhs.add_invalidates(w_res) return w_res @@ -123,6 +131,13 @@ ("maximum", "max", 2), ("minimum", "min", 2), + ("equal", "eq", 2, {"bool_result": True}), + ("not_equal", "ne", 2, {"bool_result": True}), + ("less", "lt", 2, {"bool_result": True}), + ("less_equal", "le", 2, {"bool_result": True}), + ("greater", "gt", 2, {"bool_result": True}), + ("greater_equal", "ge", 2, {"bool_result": True}), + ("copysign", "copysign", 2, {"promote_to_float": True}), ("positive", "pos", 1), diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -510,6 +510,34 @@ assert array([1.2, 5]).dtype is dtype(float) assert array([]).dtype is dtype(float) + def test_comparison(self): + from numpy import array, dtype + a = array(range(5)) + b = array(range(5), dtype=float) + for func in [ + lambda x, y: x == y, + lambda x, y: x != y, + lambda x, y: x < y, + lambda x, y: x <= y, + lambda x, y: x > y, + lambda x, y: x >= y, + ]: + _a3 = func (a, 3) + assert _a3.dtype is dtype(bool) + for i in xrange(5): + assert _a3[i] == (True if func(a[i], 3) else False) + _b3 = func (b, 3) + assert _b3.dtype is dtype(bool) + for i in xrange(5): + assert _b3[i] == (True if func(b[i], 3) else False) + _3a = func (3, a) + assert _3a.dtype is dtype(bool) + for i in xrange(5): + assert _3a[i] == (True if func(3, a[i]) else False) + _3b = func (3, b) + assert _3b.dtype is dtype(bool) + for i in xrange(5): + assert _3b[i] == (True if func(3, b[i]) else False) class AppTestSupport(object): def setup_class(cls): @@ -522,4 +550,4 @@ a = fromstring(self.data) for i in range(4): assert a[i] == i + 1 - raises(ValueError, fromstring, "abc") \ No newline at end of file + raises(ValueError, fromstring, "abc") diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py --- a/pypy/module/micronumpy/test/test_ufuncs.py +++ b/pypy/module/micronumpy/test/test_ufuncs.py @@ -267,3 +267,11 @@ b = arctan(a) assert math.isnan(b[0]) + def test_comparison(self): + from numpy import array, dtype, equal + assert equal(3, 3) is True + assert equal(3, 4) is False + assert equal(3.0, 3.0) is True + assert equal(3.0, 3.5) is False + assert equal(3.0, 3) is True + assert equal(3.0, 4) is False _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit