Author: Brian Kearns <[email protected]>
Branch: 
Changeset: r67833:1fb8f989bf6a
Date: 2013-11-04 13:43 -0500
http://bitbucket.org/pypy/pypy/changeset/1fb8f989bf6a/

Log:    fix min/max and fmin/fmax of complex values containing NaNs

diff --git a/pypy/module/micronumpy/test/dummy_module.py b/pypy/module/micronumpy/test/dummy_module.py
--- a/pypy/module/micronumpy/test/dummy_module.py
+++ b/pypy/module/micronumpy/test/dummy_module.py
@@ -1,6 +1,7 @@
 from _numpypy.multiarray import *
 from _numpypy.umath import *
 
+nan = float('nan')
 newaxis = None
 ufunc = type(sin)
 
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -319,6 +319,27 @@
         assert x == 3
         assert isinstance(x, (int, long))
 
+    def test_complex_nan_extrema(self):
+        import math
+        import numpy as np
+        cnan = complex(0, np.nan)
+
+        b = np.minimum(1, cnan)
+        assert b.real == 0
+        assert math.isnan(b.imag)
+
+        b = np.maximum(1, cnan)
+        assert b.real == 0
+        assert math.isnan(b.imag)
+
+        b = np.fmin(1, cnan)
+        assert b.real == 1
+        assert b.imag == 0
+
+        b = np.fmax(1, cnan)
+        assert b.real == 1
+        assert b.imag == 0
+
     def test_multiply(self):
         from numpypy import array, multiply, arange
 
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -1192,11 +1192,11 @@
 
     def _lt(self, v1, v2):
         (r1, i1), (r2, i2) = v1, v2
-        if r1 < r2:
+        if r1 < r2 and not rfloat.isnan(i1) and not rfloat.isnan(i2):
             return True
-        elif not r1 <= r2:
-            return False
-        return i1 < i2
+        if r1 == r2 and i1 < i2:
+            return True
+        return False
 
     @raw_binary_op
     def lt(self, v1, v2):
@@ -1234,10 +1234,14 @@
         return self._bool(v1) ^ self._bool(v2)
 
     def min(self, v1, v2):
-        return self.fmin(v1, v2)
+        if self.le(v1, v2) or self.isnan(v1):
+            return v1
+        return v2
 
     def max(self, v1, v2):
-        return self.fmax(v1, v2)
+        if self.ge(v1, v2) or self.isnan(v1):
+            return v1
+        return v2
 
     @complex_binary_op
     def floordiv(self, v1, v2):
@@ -1292,20 +1296,12 @@
         return -1,0
 
     def fmax(self, v1, v2):
-        if self.isnan(v2):
-            return v1
-        elif self.isnan(v1):
-            return v2
-        if self.ge(v1, v2):
+        if self.ge(v1, v2) or self.isnan(v2):
             return v1
         return v2
 
     def fmin(self, v1, v2):
-        if self.isnan(v2):
-            return v1
-        elif self.isnan(v1):
-            return v2
-        if self.le(v1, v2):
+        if self.le(v1, v2) or self.isnan(v2):
             return v1
         return v2
 
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to