Author: Maciej Fijalkowski <[email protected]>
Branch: numpy-back-to-applevel
Changeset: r51887:891a2ea64919
Date: 2012-01-27 22:07 +0200
http://bitbucket.org/pypy/pypy/changeset/891a2ea64919/
Log: clean up scalar reshape and ravel
diff --git a/pypy/module/micronumpy/interp_numarray.py
b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -274,7 +274,8 @@
def descr_flatten(self, space, w_order=None):
if isinstance(self, Scalar):
- return self.copy(space)
+ # scalars have no storage
+ return self.descr_reshape(space, [space.wrap([1])])
concr = self.get_concrete()
w_res = concr.descr_ravel(space, w_order)
if w_res.storage == concr.storage:
@@ -479,8 +480,11 @@
w_shape = args_w[0]
else:
w_shape = space.newtuple(args_w)
+ new_shape = get_shape_from_iterable(space, self.size, w_shape)
+ return self.reshape(space, new_shape)
+
+ def reshape(self, space, new_shape):
concrete = self.get_concrete()
- new_shape = get_shape_from_iterable(space, concrete.size, w_shape)
# Since we got to here, prod(new_shape) == self.size
new_strides = calc_new_strides(new_shape, concrete.shape,
concrete.strides, concrete.order)
@@ -693,6 +697,11 @@
def get_concrete_or_scalar(self):
return self  # the scalar object itself is returned; no evaluation or copy happens here
+ def reshape(self, space, new_shape):
+ size = support.product(new_shape)
+ res = W_NDimArray(size, new_shape, self.dtype, 'C')
+ res.setitem(0, self.value)
+ return res
class VirtualArray(BaseArray):
"""
diff --git a/pypy/module/micronumpy/test/test_numarray.py
b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -469,6 +469,13 @@
y = z.reshape(4, 3, 8)
assert y.shape == (4, 3, 8)
+ def test_scalar_reshape(self):
+ from numpypy import array
+ a = array(3)
+ assert a.reshape([1, 1]).shape == (1, 1)
+ assert a.reshape([1]).shape == (1,)
+ raises(ValueError, "a.reshape(3)")
+
def test_add(self):
from _numpypy import array
a = array(range(5))
@@ -1104,6 +1111,7 @@
def test_flatten(self):
from _numpypy import array
+ assert array(3).flatten().shape == (1,)
a = array([[1, 2], [3, 4]])
b = a.flatten()
c = a.ravel()
_______________________________________________
pypy-commit mailing list
[email protected]
http://mail.python.org/mailman/listinfo/pypy-commit