Author: Yichao Yu <yyc1...@gmail.com> Branch: numpy-generic-item Changeset: r74054:0703a6e2723e Date: 2014-09-23 12:29 -0400 http://bitbucket.org/pypy/pypy/changeset/0703a6e2723e/
Log: generic.swapaxes, improve test diff --git a/pypy/module/micronumpy/boxes.py b/pypy/module/micronumpy/boxes.py --- a/pypy/module/micronumpy/boxes.py +++ b/pypy/module/micronumpy/boxes.py @@ -416,6 +416,10 @@ self.w_flags = W_FlagsObject(self) return self.w_flags + @unwrap_spec(axis1=int, axis2=int) + def descr_swapaxes(self, space, axis1, axis2): + return self.item(space) + class W_BoolBox(W_GenericBox, PrimitiveBox): descr__new__, _get_dtype, descr_reduce = new_dtype_getter(NPY.BOOL) @@ -669,6 +673,7 @@ tostring = interp2app(W_GenericBox.descr_tostring), tobytes = interp2app(W_GenericBox.descr_tostring), reshape = interp2app(W_GenericBox.descr_reshape), + swapaxes = interp2app(W_GenericBox.descr_swapaxes), dtype = GetSetProperty(W_GenericBox.descr_get_dtype), size = GetSetProperty(W_GenericBox.descr_get_size), diff --git a/pypy/module/micronumpy/test/test_scalar.py b/pypy/module/micronumpy/test/test_scalar.py --- a/pypy/module/micronumpy/test/test_scalar.py +++ b/pypy/module/micronumpy/test/test_scalar.py @@ -301,13 +301,14 @@ def test_item_tolist(self): from numpypy import int8, int16, int32, int64, float32, float64 from numpypy import complex64, complex128 - for t in int8, int16, int32, int64: - val = t(17) - assert val == 17 - assert val.item() == 17 - assert val.tolist() == 17 - assert type(val.item()) == int - assert type(val.tolist()) == int + + def _do_test(np_type, py_type, orig_val, exp_val): + val = np_type(orig_val) + assert val == orig_val + assert val.item() == exp_val + assert val.tolist() == exp_val + assert type(val.item()) == py_type + assert type(val.tolist()) == py_type val.item(0) val.item(()) val.item((0,)) @@ -316,65 +317,54 @@ raises(TypeError, val.item, '') raises(IndexError, val.item, 2) + for t in int8, int16, int32, int64: + _do_test(t, int, 17, 17) + for t in float32, float64: - val = t(17) - assert val == 17 - assert val.item() == 17 - assert val.tolist() == 17 - assert type(val.item()) == float - assert type(val.tolist()) == float - val.item(0) - val.item(()) - val.item((0,)) - raises(ValueError, val.item, 0, 1) - raises(ValueError, val.item, 0, '') - raises(TypeError, val.item, '') - raises(IndexError, val.item, 2) + _do_test(t, float, 17, 17) for t in complex64, complex128: - val = t(17j) - assert val == 17j - assert val.item() == 17j - assert val.tolist() == 17j - assert type(val.item()) == complex - assert type(val.tolist()) == complex - val.item(0) - val.item(()) - val.item((0,)) - raises(ValueError, val.item, 0, 1) - raises(ValueError, val.item, 0, '') - raises(TypeError, val.item, '') - raises(IndexError, val.item, 2) + _do_test(t, complex, 17j, 17j) def test_transpose(self): from numpypy import int8, int16, int32, int64, float32, float64 from numpypy import complex64, complex128 - for t in int8, int16, int32, int64: - val = t(17) - assert val == 17 - assert val.transpose() == 17 - assert type(val.transpose()) == int + + def _do_test(np_type, py_type, orig_val, exp_val): + val = np_type(orig_val) + assert val == orig_val + assert val.transpose() == exp_val + assert type(val.transpose()) == py_type val.transpose(()) raises(ValueError, val.transpose, 0, 1) raises(TypeError, val.transpose, 0, '') raises(ValueError, val.transpose, 0) + for t in int8, int16, int32, int64: + _do_test(t, int, 17, 17) + for t in float32, float64: - val = t(17) - assert val == 17 - assert val.transpose() == 17 - assert type(val.transpose()) == float - val.transpose(()) - raises(ValueError, val.transpose, 0, 1) - raises(TypeError, val.transpose, 0, '') - raises(ValueError, val.transpose, 0) + _do_test(t, float, 17, 17) for t in complex64, complex128: - val = t(17j) - assert val == 17j - assert val.transpose() == 17j - assert type(val.transpose()) == complex - val.transpose(()) - raises(ValueError, val.transpose, 0, 1) - raises(TypeError, val.transpose, 0, '') - raises(ValueError, val.transpose, 0) + _do_test(t, complex, 17j, 17j) + + def test_swapaxes(self): + from numpypy import int8, int16, int32, int64, float32, float64 + from numpypy import complex64, complex128 + + def _do_test(np_type, py_type, orig_val, exp_val): + val = np_type(orig_val) + assert val == orig_val + assert val.swapaxes(10, 20) == exp_val + assert type(val.swapaxes(0, 1)) == py_type + raises(TypeError, val.swapaxes, 0, ()) + + for t in int8, int16, int32, int64: + _do_test(t, int, 17, 17) + + for t in float32, float64: + _do_test(t, float, 17, 17) + + for t in complex64, complex128: + _do_test(t, complex, 17j, 17j) _______________________________________________ pypy-commit mailing list pypy-commit@python.org https://mail.python.org/mailman/listinfo/pypy-commit