Author: mattip <[email protected]>
Branch: numpy-fixes
Changeset: r76982:531f883e82b9
Date: 2015-05-02 21:40 +0300
http://bitbucket.org/pypy/pypy/changeset/531f883e82b9/

Log:    refactor comparison_func -> bool_result, special case logical_and,
        logical_or

diff --git a/pypy/module/micronumpy/test/test_object_arrays.py b/pypy/module/micronumpy/test/test_object_arrays.py
--- a/pypy/module/micronumpy/test/test_object_arrays.py
+++ b/pypy/module/micronumpy/test/test_object_arrays.py
@@ -52,8 +52,6 @@
         import numpy as np
         import sys
 
-        if '__pypy__' in sys.builtin_module_names:
-            skip('need to refactor use of raw_xxx_op in types to make this work')
         a = np.array(["foo"], dtype=object)
         b = np.array([1], dtype=object)
         d = np.array([complex(1, 10)], dtype=object)
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -298,13 +298,17 @@
     def ge(self, v1, v2):
         return v1 >= v2
 
-    @raw_binary_op
+    @simple_binary_op
     def logical_and(self, v1, v2):
-        return bool(v1) and bool(v2)
+        if bool(v1) and bool(v2):
+            return Bool._True
+        return Bool._False
 
-    @raw_binary_op
+    @simple_binary_op
     def logical_or(self, v1, v2):
-        return bool(v1) or bool(v2)
+        if bool(v1) or bool(v2):
+            return Bool._True
+        return Bool._False
 
     @raw_unary_op
     def logical_not(self, v):
@@ -1282,13 +1286,17 @@
     def _cbool(self, v):
         return bool(v[0]) or bool(v[1])
 
-    @raw_binary_op
+    @simple_binary_op
     def logical_and(self, v1, v2):
-        return self._cbool(v1) and self._cbool(v2)
+        if self._cbool(v1) and self._cbool(v2):
+            return Bool._True
+        return Bool._False
 
     @raw_binary_op
     def logical_or(self, v1, v2):
-        return self._cbool(v1) or self._cbool(v2)
+        if self._cbool(v1) or self._cbool(v2):
+            return Bool._True
+        return Bool._False
 
     @raw_unary_op
     def logical_not(self, v):
@@ -1811,14 +1819,14 @@
     @raw_binary_op
     def logical_and(self, v1, v2):
         if self._obool(v1):
-            return self.space.bool_w(v2)
-        return self.space.bool_w(v1)
+            return self.box(v2)
+        return self.box(v1)
 
     @raw_binary_op
     def logical_or(self, v1, v2):
         if self._obool(v1):
-            return self.space.bool_w(v1)
-        return self.space.bool_w(v2)
+            return self.box(v1)
+        return self.box(v2)
 
     @raw_unary_op
     def logical_not(self, v):
@@ -2062,11 +2070,15 @@
 
     @str_binary_op
     def logical_and(self, v1, v2):
-        return bool(v1) and bool(v2)
+        if bool(v1) and bool(v2):
+            return Bool._True
+        return Bool._False
 
     @str_binary_op
     def logical_or(self, v1, v2):
-        return bool(v1) or bool(v2)
+        if bool(v1) or bool(v2):
+            return Bool._True
+        return Bool._False
 
     @str_unary_op
     def logical_not(self, v):
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -209,7 +209,7 @@
                 axis += shapelen
         assert axis >= 0
         dtype = decode_w_dtype(space, dtype)
-        if self.comparison_func:
+        if self.bool_result:
             dtype = get_dtype_cache(space).w_booldtype
         elif dtype is None:
             dtype = find_unaryop_result_dtype(
@@ -395,19 +395,19 @@
 
 
 class W_Ufunc2(W_Ufunc):
-    _immutable_fields_ = ["func", "comparison_func", "done_func"]
+    _immutable_fields_ = ["func", "bool_result", "done_func"]
     nin = 2
     nout = 1
     nargs = 3
     signature = None
 
     def __init__(self, func, name, promote_to_largest=False, promote_to_float=False,
-            promote_bools=False, identity=None, comparison_func=False, int_only=False,
+            promote_bools=False, identity=None, bool_result=False, int_only=False,
             allow_bool=True, allow_complex=True, complex_to_float=False):
         W_Ufunc.__init__(self, name, promote_to_largest, promote_to_float, promote_bools,
                          identity, int_only, allow_bool, allow_complex, complex_to_float)
         self.func = func
-        self.comparison_func = comparison_func
+        self.bool_result = bool_result
         if name == 'logical_and':
             self.done_func = done_if_false
         elif name == 'logical_or':
@@ -439,20 +439,20 @@
         if w_ldtype.is_object() or w_rdtype.is_object():
             pass
         elif w_ldtype.is_str() and w_rdtype.is_str() and \
-                self.comparison_func:
+                self.bool_result:
             pass
         elif (w_ldtype.is_str()) and \
-                self.comparison_func and w_out is None:
+                self.bool_result and w_out is None:
             if self.name in ('equal', 'less_equal', 'less'):
                return space.wrap(False)
             return space.wrap(True) 
         elif (w_rdtype.is_str()) and \
-                self.comparison_func and w_out is None:
+                self.bool_result and w_out is None:
             if self.name in ('not_equal','less', 'less_equal'):
                return space.wrap(True)
             return space.wrap(False)
         elif w_ldtype.is_flexible() or w_rdtype.is_flexible():
-            if self.comparison_func:
+            if self.bool_result:
                 if self.name == 'equal' or self.name == 'not_equal':
                     res = w_ldtype.eq(space, w_rdtype)
                     if not res:
@@ -490,7 +490,7 @@
         else:
             out = w_out
             calc_dtype = out.get_dtype()
-        if self.comparison_func:
+        if self.bool_result:
             res_dtype = get_dtype_cache(space).w_booldtype
         else:
             res_dtype = calc_dtype
@@ -1121,8 +1121,7 @@
     #            'supported', w_obj)
 
 
-def ufunc_dtype_caller(space, ufunc_name, op_name, nin, comparison_func,
-                       bool_result):
+def ufunc_dtype_caller(space, ufunc_name, op_name, nin, bool_result):
     def get_op(dtype):
         try:
             return getattr(dtype.itemtype, op_name)
@@ -1140,7 +1139,7 @@
     elif nin == 2:
         def impl(res_dtype, lvalue, rvalue):
             res = get_op(res_dtype)(lvalue, rvalue)
-            if comparison_func:
+            if bool_result:
                 return dtype_cache.w_booldtype.box(res)
             return res
     return func_with_new_name(impl, ufunc_name)
@@ -1167,21 +1166,19 @@
             ("left_shift", "lshift", 2, {"int_only": True}),
             ("right_shift", "rshift", 2, {"int_only": True}),
 
-            ("equal", "eq", 2, {"comparison_func": True}),
-            ("not_equal", "ne", 2, {"comparison_func": True}),
-            ("less", "lt", 2, {"comparison_func": True}),
-            ("less_equal", "le", 2, {"comparison_func": True}),
-            ("greater", "gt", 2, {"comparison_func": True}),
-            ("greater_equal", "ge", 2, {"comparison_func": True}),
+            ("equal", "eq", 2, {"bool_result": True}),
+            ("not_equal", "ne", 2, {"bool_result": True}),
+            ("less", "lt", 2, {"bool_result": True}),
+            ("less_equal", "le", 2, {"bool_result": True}),
+            ("greater", "gt", 2, {"bool_result": True}),
+            ("greater_equal", "ge", 2, {"bool_result": True}),
             ("isnan", "isnan", 1, {"bool_result": True}),
             ("isinf", "isinf", 1, {"bool_result": True}),
             ("isfinite", "isfinite", 1, {"bool_result": True}),
 
-            ('logical_and', 'logical_and', 2, {'comparison_func': True,
-                                               'identity': 1}),
-            ('logical_or', 'logical_or', 2, {'comparison_func': True,
-                                             'identity': 0}),
-            ('logical_xor', 'logical_xor', 2, {'comparison_func': True}),
+            ('logical_and', 'logical_and', 2, {'identity': 1}),
+            ('logical_or', 'logical_or', 2, {'identity': 0}),
+            ('logical_xor', 'logical_xor', 2, {'bool_result': True}),
             ('logical_not', 'logical_not', 1, {'bool_result': True}),
 
             ("maximum", "max", 2),
@@ -1263,7 +1260,6 @@
         extra_kwargs["identity"] = identity
 
         func = ufunc_dtype_caller(space, ufunc_name, op_name, nin,
-            comparison_func=extra_kwargs.get("comparison_func", False),
             bool_result=extra_kwargs.get("bool_result", False),
         )
         if nin == 1:
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to