Author: Romain Guillebert <[email protected]>
Branch: 
Changeset: r67202:2e8639dfd82e
Date: 2013-10-08 14:50 +0200
http://bitbucket.org/pypy/pypy/changeset/2e8639dfd82e/

Log:    Fix pypy issue 1598

diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -319,6 +319,15 @@
         else:
             self.done_func = None
 
+    def are_common_types(self, dtype1, dtype2):
+        if dtype1.is_complex_type() and dtype2.is_complex_type():
+            return True
+        elif not (dtype1.is_complex_type() or dtype2.is_complex_type()) and \
+                (dtype1.is_int_type() and dtype2.is_int_type() or dtype1.is_float_type() and dtype2.is_float_type()) and \
+                not (dtype1.is_bool_type() or dtype2.is_bool_type()):
+            return True
+        return False
+
     @jit.unroll_safe
     def call(self, space, args_w):
         if len(args_w) > 2:
@@ -339,6 +348,12 @@
                  'unsupported operand dtypes %s and %s for "%s"' % \
                  (w_rdtype.get_name(), w_ldtype.get_name(),
                   self.name)))
+
+        if self.are_common_types(w_ldtype, w_rdtype):
+            if not w_lhs.is_scalar() and w_rhs.is_scalar():
+                w_rdtype = w_ldtype
+            elif w_lhs.is_scalar() and not w_rhs.is_scalar():
+                w_ldtype = w_rdtype
         calc_dtype = find_binop_result_dtype(space,
             w_ldtype, w_rdtype,
             int_only=self.int_only,
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -2971,6 +2971,11 @@
                         dtype=[('bg', 'i8'), ('fg', 'i8'), ('char', 'S1')])
         assert c[0][0]["char"] == 'a'
 
+    def test_scalar_coercion(self):
+        import numpypy as np
+        a = np.array([1,2,3], dtype=np.int16)
+        assert (a * 2).dtype == np.int16
+
 class AppTestPyPy(BaseNumpyAppTest):
     def setup_class(cls):
         if option.runappdirect and '__pypy__' not in sys.builtin_module_names:
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to