Author: Maciej Fijalkowski <fij...@gmail.com>
Branch: missing-ndarray-attributes
Changeset: r58584:47b57e79e2fa
Date: 2012-10-29 14:59 +0100
http://bitbucket.org/pypy/pypy/changeset/47b57e79e2fa/

Log:    implement ndarray.choose

diff --git a/pypy/module/micronumpy/arrayimpl/concrete.py b/pypy/module/micronumpy/arrayimpl/concrete.py
--- a/pypy/module/micronumpy/arrayimpl/concrete.py
+++ b/pypy/module/micronumpy/arrayimpl/concrete.py
@@ -13,12 +13,6 @@
 from pypy.module.micronumpy.arrayimpl.sort import argsort_array
 from pypy.rlib.debug import make_sure_not_resized
 
-def int_w(space, w_obj):
-    try:
-        return space.int_w(space.index(w_obj))
-    except OperationError:
-        return space.int_w(space.int(w_obj))
-
 class BaseConcreteArray(base.BaseArrayImplementation):
     start = 0
     parent = None
@@ -85,7 +79,7 @@
         for i, w_index in enumerate(view_w):
             if space.isinstance_w(w_index, space.w_slice):
                 raise IndexError
-            idx = int_w(space, w_index)
+            idx = support.int_w(space, w_index)
             if idx < 0:
                 idx = self.get_shape()[i] + idx
             if idx < 0 or idx >= self.get_shape()[i]:
@@ -159,7 +153,7 @@
             return self._lookup_by_index(space, view_w)
         if shape_len > 1:
             raise IndexError
-        idx = int_w(space, w_idx)
+        idx = support.int_w(space, w_idx)
         return self._lookup_by_index(space, [space.wrap(idx)])
 
     @jit.unroll_safe
diff --git a/pypy/module/micronumpy/constants.py b/pypy/module/micronumpy/constants.py
new file mode 100644
--- /dev/null
+++ b/pypy/module/micronumpy/constants.py
@@ -0,0 +1,4 @@
+
+MODE_WRAP, MODE_RAISE, MODE_CLIP = range(3)
+
+MODES = {'wrap': MODE_WRAP, 'raise': MODE_RAISE, 'clip': MODE_CLIP}
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -3,6 +3,7 @@
 from pypy.module.micronumpy import loop, interp_ufuncs
 from pypy.module.micronumpy.iter import Chunk, Chunks
 from pypy.module.micronumpy.strides import shape_agreement
+from pypy.module.micronumpy.constants import MODES
 from pypy.interpreter.error import OperationError, operationerrfmt
 from pypy.interpreter.gateway import unwrap_spec
 
@@ -153,3 +154,28 @@
 
 def count_nonzero(space, w_obj):
     return space.wrap(loop.count_all_true(convert_to_array(space, w_obj)))
+
+def choose(space, arr, w_choices, out, mode):
+    choices = [convert_to_array(space, w_item) for w_item
+               in space.listview(w_choices)]
+    if not choices:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("choices list cannot be empty"))
+    # find the shape agreement
+    shape = arr.get_shape()
+    for choice in choices:
+        shape = shape_agreement(space, shape, choice)
+    if out is not None:
+        shape = shape_agreement(space, shape, out)
+    # find the correct dtype
+    dtype = choices[0].get_dtype()
+    for choice in choices[1:]:
+        dtype = interp_ufuncs.find_binop_result_dtype(space,
+                                                      dtype, choice.get_dtype())
+    if out is None:
+        out = W_NDimArray.from_shape(shape, dtype)
+    if mode not in MODES:
+        raise OperationError(space.w_ValueError,
+                             space.wrap("mode %s not known" % (mode,)))
+    loop.choose(space, arr, choices, shape, dtype, out, MODES[mode])
+    return out
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -4,7 +4,8 @@
 from pypy.interpreter.gateway import interp2app, unwrap_spec, WrappedDefault
 from pypy.module.micronumpy.base import W_NDimArray, convert_to_array,\
      ArrayArgumentException
-from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes
+from pypy.module.micronumpy import interp_dtype, interp_ufuncs, interp_boxes,\
+     interp_arrayops
 from pypy.module.micronumpy.strides import find_shape_and_elems,\
      get_shape_from_iterable, to_coords, shape_agreement
 from pypy.module.micronumpy.interp_flatiter import W_FlatIterator
@@ -402,9 +403,11 @@
             return res
 
     @unwrap_spec(mode=str)
-    def descr_choose(self, space, w_choices, w_out=None, mode='raise'):
-        raise OperationError(space.w_NotImplementedError, space.wrap(
-            "choose not implemented yet"))
+    def descr_choose(self, space, w_choices, mode='raise', w_out=None):
+        if w_out is not None and not isinstance(w_out, W_NDimArray):
+            raise OperationError(space.w_TypeError, space.wrap(
+                "return arrays must be of ArrayType"))
+        return interp_arrayops.choose(space, self, w_choices, w_out, mode)
 
     def descr_clip(self, space, w_min, w_max, w_out=None):
         raise OperationError(space.w_NotImplementedError, space.wrap(
diff --git a/pypy/module/micronumpy/loop.py b/pypy/module/micronumpy/loop.py
--- a/pypy/module/micronumpy/loop.py
+++ b/pypy/module/micronumpy/loop.py
@@ -4,11 +4,14 @@
 over all the array elements.
 """
 
+from pypy.interpreter.error import OperationError
 from pypy.rlib.rstring import StringBuilder
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype, rffi
 from pypy.module.micronumpy.base import W_NDimArray
 from pypy.module.micronumpy.iter import PureShapeIterator
+from pypy.module.micronumpy import constants
+from pypy.module.micronumpy.support import int_w
 
 call2_driver = jit.JitDriver(name='numpy_call2',
                              greens = ['shapelen', 'func', 'calc_dtype',
@@ -486,3 +489,36 @@
         to_iter.setitem(dtype.itemtype.byteswap(from_iter.getitem()))
         to_iter.next()
         from_iter.next()
+
+choose_driver = jit.JitDriver(greens = ['shapelen', 'mode', 'dtype'],
+                              reds = ['shape', 'iterators', 'arr_iter',
+                                      'out_iter'])
+
+def choose(space, arr, choices, shape, dtype, out, mode):
+    shapelen = len(shape)
+    iterators = [a.create_iter(shape) for a in choices]
+    arr_iter = arr.create_iter(shape)
+    out_iter = out.create_iter(shape)
+    while not arr_iter.done():
+        choose_driver.jit_merge_point(shapelen=shapelen, dtype=dtype,
+                                      mode=mode, shape=shape,
+                                      iterators=iterators, arr_iter=arr_iter,
+                                      out_iter=out_iter)
+        index = int_w(space, arr_iter.getitem())
+        if index < 0 or index >= len(iterators):
+            if mode == constants.MODE_RAISE:
+                raise OperationError(space.w_ValueError, space.wrap(
+                    "invalid entry in choice array"))
+            elif mode == constants.MODE_WRAP:
+                index = index % (len(iterators))
+            else:
+                assert mode == constants.MODE_CLIP
+                if index < 0:
+                    index = 0
+                else:
+                    index = len(iterators) - 1
+        out_iter.setitem(iterators[index].getitem().convert_to(dtype))
+        for iter in iterators:
+            iter.next()
+        out_iter.next()
+        arr_iter.next()
diff --git a/pypy/module/micronumpy/support.py b/pypy/module/micronumpy/support.py
--- a/pypy/module/micronumpy/support.py
+++ b/pypy/module/micronumpy/support.py
@@ -1,5 +1,11 @@
 from pypy.rlib import jit
+from pypy.interpreter.error import OperationError
 
+def int_w(space, w_obj):
+    try:
+        return space.int_w(space.index(w_obj))
+    except OperationError:
+        return space.int_w(space.int(w_obj))
 
 @jit.unroll_safe
 def product(s):
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -108,8 +108,16 @@
         a, b, c = array([1, 2, 3]), [4, 5, 6], 13
         raises(ValueError, "array([3, 1, 0]).choose([a, b, c])")
         raises(ValueError, "array([3, 1, 0]).choose([a, b, c], 'raises')")
+        raises(ValueError, "array([3, 1, 0]).choose([])")
+        raises(ValueError, "array([-1, -2, -3]).choose([a, b, c])")
         r = array([4, 1, 0]).choose([a, b, c], mode='clip')
         assert (r == [13, 5, 3]).all()
         r = array([4, 1, 0]).choose([a, b, c], mode='wrap')
         assert (r == [4, 5, 3]).all()
-        
+
+
+    def test_choose_dtype(self):
+        from _numpypy import array
+        a, b, c = array([1.2, 2, 3]), [4, 5, 6], 13
+        r = array([2, 1, 0]).choose([a, b, c])
+        assert r.dtype == float
_______________________________________________
pypy-commit mailing list
pypy-commit@python.org
http://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to