Author: mattip Branch: numpypy-shape-bug Changeset: r51671:5a8fc969e644 Date: 2012-01-19 01:55 +0200 http://bitbucket.org/pypy/pypy/changeset/5a8fc969e644/
Log: add more tests diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py --- a/pypy/module/micronumpy/interp_numarray.py +++ b/pypy/module/micronumpy/interp_numarray.py @@ -183,10 +183,9 @@ n_old_elems_to_use *= old_shape[oldI] if n_new_elems_used == n_old_elems_to_use: oldI += 1 - if oldI >= len(old_shape): - break - cur_step = steps[oldI] - n_old_elems_to_use *= old_shape[oldI] + if oldI < len(old_shape): + cur_step = steps[oldI] + n_old_elems_to_use *= old_shape[oldI] elif order == 'C': for i in range(len(old_shape) - 1, -1, -1): steps.insert(0, old_strides[i] / last_step) @@ -206,10 +205,10 @@ n_old_elems_to_use *= old_shape[oldI] if n_new_elems_used == n_old_elems_to_use: oldI -= 1 - if oldI < -len(old_shape): - break - cur_step = steps[oldI] - n_old_elems_to_use *= old_shape[oldI] + if oldI >= -len(old_shape): + cur_step = steps[oldI] + n_old_elems_to_use *= old_shape[oldI] + assert len(new_strides) == len(new_shape) return new_strides class BaseArray(Wrappable): diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py --- a/pypy/module/micronumpy/test/test_numarray.py +++ b/pypy/module/micronumpy/test/test_numarray.py @@ -157,6 +157,13 @@ assert calc_new_strides([2, 3, 4], [8, 3], [1, 16], 'F') is None assert calc_new_strides([24], [2, 4, 3], [48, 6, 1], 'C') is None assert calc_new_strides([24], [2, 4, 3], [24, 6, 2], 'C') == [2] + assert calc_new_strides([105, 1], [3, 5, 7], [35, 7, 1],'C') == [1, 1] + assert calc_new_strides([1, 105], [3, 5, 7], [35, 7, 1],'C') == [105, 1] + assert calc_new_strides([1, 105], [3, 5, 7], [35, 7, 1],'F') is None + assert calc_new_strides([1, 1, 1, 105, 1], [15, 7], [7, 1],'C') == \ + [105, 105, 105, 1, 1] + assert calc_new_strides([1, 1, 105, 1, 1], [7, 15], [1, 7],'F') == \ + [1, 1, 1, 105, 105] class AppTestNumArray(BaseNumpyAppTest): @@ -767,7 +774,6 @@ assert (a[:, 1, :].sum(1) == [70, 315, 560]).all() raises (ValueError, 'a[:, 1, :].sum(2)') assert ((a + a).T.sum(2).T == (a + a).sum(0)).all() - skip("Those are broken, fix after removing Scalar") assert (a.reshape(1,-1).sum(0) == range(105)).all() assert (a.reshape(1,-1).sum(1) == 5460) _______________________________________________ pypy-commit mailing list pypy-commit@python.org http://mail.python.org/mailman/listinfo/pypy-commit