Author: Brian Kearns <[email protected]>
Branch: numpy-refactor
Changeset: r69480:568a574ba925
Date: 2014-02-26 18:18 -0500
http://bitbucket.org/pypy/pypy/changeset/568a574ba925/

Log:    fix newaxis indexing and assignment on scalar (0-d) arrays

diff --git a/pypy/module/micronumpy/concrete.py b/pypy/module/micronumpy/concrete.py
--- a/pypy/module/micronumpy/concrete.py
+++ b/pypy/module/micronumpy/concrete.py
@@ -205,10 +205,10 @@
                 raise OperationError(space.w_ValueError, space.wrap(
                     "field named %s not found" % idx))
             return RecordChunk(idx)
-        if len(self.get_shape()) == 0:
-            raise oefmt(space.w_ValueError, "cannot slice a 0-d array")
-        if (space.isinstance_w(w_idx, space.w_int) or
+        elif (space.isinstance_w(w_idx, space.w_int) or
                 space.isinstance_w(w_idx, space.w_slice)):
+            if len(self.get_shape()) == 0:
+                raise oefmt(space.w_ValueError, "cannot slice a 0-d array")
             return Chunks([Chunk(*space.decode_index4(w_idx, 
self.get_shape()[0]))])
         elif isinstance(w_idx, W_NDimArray) and w_idx.is_scalar():
             w_idx = w_idx.get_scalar_value().item(space)
diff --git a/pypy/module/micronumpy/strides.py b/pypy/module/micronumpy/strides.py
--- a/pypy/module/micronumpy/strides.py
+++ b/pypy/module/micronumpy/strides.py
@@ -27,12 +27,16 @@
     i = -1
     j = 0
     for i, chunk in enumerate_chunks(chunks):
+        try:
+            s_i = strides[i]
+        except IndexError:
+            continue
         if chunk.step != 0:
-            rstrides[j] = strides[i] * chunk.step
-            rbackstrides[j] = strides[i] * max(0, chunk.lgt - 1) * chunk.step
+            rstrides[j] = s_i * chunk.step
+            rbackstrides[j] = s_i * max(0, chunk.lgt - 1) * chunk.step
             rshape[j] = chunk.lgt
             j += 1
-        rstart += strides[i] * chunk.start
+        rstart += s_i * chunk.start
     # add a reminder
     s = i + 1
     assert s >= 0
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -716,9 +716,14 @@
             for y in range(2):
                 expected[x, y] = math.cos(a[x]) * math.cos(b[y])
         assert ((cos(a)[:,newaxis] * cos(b).T) == expected).all()
-        a = array(1)[newaxis]
+        o = array(1)
+        a = o[newaxis]
         assert a == array([1])
         assert a.shape == (1,)
+        o[newaxis, newaxis] = 2
+        assert o == 2
+        a[:] = 3
+        assert o == 3
 
     def test_newaxis_slice(self):
         from numpypy import array, newaxis
_______________________________________________
pypy-commit mailing list
[email protected]
https://mail.python.org/mailman/listinfo/pypy-commit

Reply via email to