Author: Brian Kearns <[email protected]>
Branch:
Changeset: r67833:1fb8f989bf6a
Date: 2013-11-04 13:43 -0500
http://bitbucket.org/pypy/pypy/changeset/1fb8f989bf6a/
Log: fix min/max of complex with nans
diff --git a/pypy/module/micronumpy/test/dummy_module.py b/pypy/module/micronumpy/test/dummy_module.py
--- a/pypy/module/micronumpy/test/dummy_module.py
+++ b/pypy/module/micronumpy/test/dummy_module.py
@@ -1,6 +1,7 @@
from _numpypy.multiarray import *
from _numpypy.umath import *
+nan = float('nan')
newaxis = None
ufunc = type(sin)
diff --git a/pypy/module/micronumpy/test/test_ufuncs.py b/pypy/module/micronumpy/test/test_ufuncs.py
--- a/pypy/module/micronumpy/test/test_ufuncs.py
+++ b/pypy/module/micronumpy/test/test_ufuncs.py
@@ -319,6 +319,27 @@
assert x == 3
assert isinstance(x, (int, long))
+ def test_complex_nan_extrema(self):
+ import math
+ import numpy as np
+ cnan = complex(0, np.nan)
+
+ b = np.minimum(1, cnan)
+ assert b.real == 0
+ assert math.isnan(b.imag)
+
+ b = np.maximum(1, cnan)
+ assert b.real == 0
+ assert math.isnan(b.imag)
+
+ b = np.fmin(1, cnan)
+ assert b.real == 1
+ assert b.imag == 0
+
+ b = np.fmax(1, cnan)
+ assert b.real == 1
+ assert b.imag == 0
+
def test_multiply(self):
from numpypy import array, multiply, arange
diff --git a/pypy/module/micronumpy/types.py b/pypy/module/micronumpy/types.py
--- a/pypy/module/micronumpy/types.py
+++ b/pypy/module/micronumpy/types.py
@@ -1192,11 +1192,11 @@
def _lt(self, v1, v2):
(r1, i1), (r2, i2) = v1, v2
- if r1 < r2:
+ if r1 < r2 and not rfloat.isnan(i1) and not rfloat.isnan(i2):
return True
- elif not r1 <= r2:
- return False
- return i1 < i2
+ if r1 == r2 and i1 < i2:
+ return True
+ return False
@raw_binary_op
def lt(self, v1, v2):
@@ -1234,10 +1234,14 @@
return self._bool(v1) ^ self._bool(v2)
def min(self, v1, v2):
- return self.fmin(v1, v2)
+ if self.le(v1, v2) or self.isnan(v1):
+ return v1
+ return v2
def max(self, v1, v2):
- return self.fmax(v1, v2)
+ if self.ge(v1, v2) or self.isnan(v1):
+ return v1
+ return v2
@complex_binary_op
def floordiv(self, v1, v2):
@@ -1292,20 +1296,12 @@
return -1,0
def fmax(self, v1, v2):
- if self.isnan(v2):
- return v1
- elif self.isnan(v1):
- return v2
- if self.ge(v1, v2):
+ if self.ge(v1, v2) or self.isnan(v2):
return v1
return v2
def fmin(self, v1, v2):
- if self.isnan(v2):
- return v1
- elif self.isnan(v1):
- return v2
- if self.le(v1, v2):
+ if self.le(v1, v2) or self.isnan(v2):
return v1
return v2
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit