Author: Brian Kearns <[email protected]>
Branch:
Changeset: r71116:891a19edd076
Date: 2014-04-30 23:51 -0400
http://bitbucket.org/pypy/pypy/changeset/891a19edd076/
Log: support ndarray.clip with only one of min or max
diff --git a/pypy/module/micronumpy/descriptor.py b/pypy/module/micronumpy/descriptor.py
--- a/pypy/module/micronumpy/descriptor.py
+++ b/pypy/module/micronumpy/descriptor.py
@@ -29,9 +29,11 @@
if not space.is_none(out):
return out
- dtype = w_arr_list[0].get_dtype()
- for w_arr in w_arr_list[1:]:
- dtype = find_binop_result_dtype(space, dtype, w_arr.get_dtype())
+ dtype = None
+ for w_arr in w_arr_list:
+ if not space.is_none(w_arr):
+ dtype = find_binop_result_dtype(space, dtype, w_arr.get_dtype())
+ assert dtype is not None
out = base.W_NDimArray.from_shape(space, shape, dtype)
return out
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -606,25 +606,34 @@
reds = 'auto')
def clip(space, arr, shape, min, max, out):
+ assert min or max
arr_iter, arr_state = arr.create_iter(shape)
+ if min is not None:
+ min_iter, min_state = min.create_iter(shape)
+ else:
+ min_iter, min_state = None, None
+ if max is not None:
+ max_iter, max_state = max.create_iter(shape)
+ else:
+ max_iter, max_state = None, None
+ out_iter, out_state = out.create_iter(shape)
+ shapelen = len(shape)
dtype = out.get_dtype()
- shapelen = len(shape)
- min_iter, min_state = min.create_iter(shape)
- max_iter, max_state = max.create_iter(shape)
- out_iter, out_state = out.create_iter(shape)
while not arr_iter.done(arr_state):
clip_driver.jit_merge_point(shapelen=shapelen, dtype=dtype)
w_v = arr_iter.getitem(arr_state).convert_to(space, dtype)
- w_min = min_iter.getitem(min_state).convert_to(space, dtype)
- w_max = max_iter.getitem(max_state).convert_to(space, dtype)
- if dtype.itemtype.lt(w_v, w_min):
- w_v = w_min
- elif dtype.itemtype.gt(w_v, w_max):
- w_v = w_max
+ arr_state = arr_iter.next(arr_state)
+ if min_iter is not None:
+ w_min = min_iter.getitem(min_state).convert_to(space, dtype)
+ if dtype.itemtype.lt(w_v, w_min):
+ w_v = w_min
+ min_state = min_iter.next(min_state)
+ if max_iter is not None:
+ w_max = max_iter.getitem(max_state).convert_to(space, dtype)
+ if dtype.itemtype.gt(w_v, w_max):
+ w_v = w_max
+ max_state = max_iter.next(max_state)
out_iter.setitem(out_state, w_v)
- arr_state = arr_iter.next(arr_state)
- min_state = min_iter.next(min_state)
- max_state = max_iter.next(max_state)
out_state = out_iter.next(out_state)
round_driver = jit.JitDriver(name='numpy_round_driver',
diff --git a/pypy/module/micronumpy/ndarray.py b/pypy/module/micronumpy/ndarray.py
--- a/pypy/module/micronumpy/ndarray.py
+++ b/pypy/module/micronumpy/ndarray.py
@@ -593,17 +593,25 @@
def descr_choose(self, space, w_choices, w_out=None, w_mode=None):
return choose(space, self, w_choices, w_out, w_mode)
- def descr_clip(self, space, w_min, w_max, w_out=None):
+ def descr_clip(self, space, w_min=None, w_max=None, w_out=None):
+ if space.is_none(w_min):
+ w_min = None
+ else:
+ w_min = convert_to_array(space, w_min)
+ if space.is_none(w_max):
+ w_max = None
+ else:
+ w_max = convert_to_array(space, w_max)
if space.is_none(w_out):
w_out = None
elif not isinstance(w_out, W_NDimArray):
raise OperationError(space.w_TypeError, space.wrap(
"return arrays must be of ArrayType"))
- min = convert_to_array(space, w_min)
- max = convert_to_array(space, w_max)
- shape = shape_agreement_multiple(space, [self, min, max, w_out])
- out = descriptor.dtype_agreement(space, [self, min, max], shape, w_out)
- loop.clip(space, self, shape, min, max, out)
+ if not w_min and not w_max:
+ raise oefmt(space.w_ValueError, "One of max or min must be given.")
+ shape = shape_agreement_multiple(space, [self, w_min, w_max, w_out])
+ out = descriptor.dtype_agreement(space, [self, w_min, w_max], shape, w_out)
+ loop.clip(space, self, shape, w_min, w_max, out)
return out
def descr_get_ctypes(self, space):
diff --git a/pypy/module/micronumpy/test/test_ndarray.py b/pypy/module/micronumpy/test/test_ndarray.py
--- a/pypy/module/micronumpy/test/test_ndarray.py
+++ b/pypy/module/micronumpy/test/test_ndarray.py
@@ -2229,7 +2229,13 @@
def test_clip(self):
from numpypy import array
a = array([1, 2, 17, -3, 12])
+ exc = raises(ValueError, a.clip)
+ assert str(exc.value) == "One of max or min must be given."
assert (a.clip(-2, 13) == [1, 2, 13, -2, 12]).all()
+ assert (a.clip(min=-2) == [1, 2, 17, -2, 12]).all()
+ assert (a.clip(min=-2, max=None) == [1, 2, 17, -2, 12]).all()
+ assert (a.clip(max=13) == [1, 2, 13, -3, 12]).all()
+ assert (a.clip(min=None, max=13) == [1, 2, 13, -3, 12]).all()
assert (a.clip(-1, 1, out=None) == [1, 1, 1, -1, 1]).all()
assert (a == [1, 2, 17, -3, 12]).all()
assert (a.clip(-1, [1, 2, 3, 4, 5]) == [1, 2, 3, -1, 5]).all()
diff --git a/pypy/module/micronumpy/ufuncs.py b/pypy/module/micronumpy/ufuncs.py
--- a/pypy/module/micronumpy/ufuncs.py
+++ b/pypy/module/micronumpy/ufuncs.py
@@ -477,6 +477,8 @@
promote_bools=False):
if dt2 is None:
return dt1
+ if dt1 is None:
+ return dt2
# dt1.num should be <= dt2.num
if dt1.num > dt2.num:
dt1, dt2 = dt2, dt1
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
https://mail.python.org/mailman/listinfo/pypy-commit