piiswrong closed pull request #10417: [MXNET-283] Error handling for non-positive reps of tile op
URL: https://github.com/apache/incubator-mxnet/pull/10417
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 9aaac5f4a32..7756119ba3f 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -1442,7 +1442,8 @@ struct TileParam : public dmlc::Parameter<TileParam> { TShape reps; DMLC_DECLARE_PARAMETER(TileParam) { DMLC_DECLARE_FIELD(reps) - .describe("The number of times for repeating the tensor a." + .describe("The number of times for repeating the tensor a. Each dim size of reps" + " must be a positive integer." " If reps has length d, the result will have dimension of max(d, a.ndim);" " If a.ndim < d, a is promoted to be d-dimensional by prepending new axes." " If a.ndim > d, reps is promoted to a.ndim by pre-pending 1's to it."); @@ -1462,6 +1463,9 @@ inline bool TileOpShape(const nnvm::NodeAttrs& attrs, SHAPE_ASSIGN_CHECK(*out_attrs, 0, ishape); return true; } + for (size_t i = 0; i < reps.ndim(); ++i) { + CHECK_GT(reps[i], 0) << "invalid reps=" << i << ", dim size must be greater than zero"; + } TShape oshape(std::max(ishape.ndim(), reps.ndim())); int i1 = static_cast<int>(ishape.ndim()) - 1; int i2 = static_cast<int>(reps.ndim()) - 1; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 629304da533..e92ac999972 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -24,7 +24,7 @@ import itertools from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * -from mxnet.base import py_str +from mxnet.base import py_str, MXNetError from common import setup_module, with_seed import unittest @@ -3480,24 +3480,23 @@ def test_reverse(): @with_seed() def 
test_tile(): def test_normal_case(): - ndim_max = 3 # max number of dims of the ndarray - size_max = 10 # max number of elements in each dim - length_max = 3 # max length of reps - rep_max = 10 # max number of tiling in each dim - for ndim in range(ndim_max, ndim_max+1): - shape = () - for i in range(0, ndim): - shape += (np.random.randint(1, size_max+1), ) + ndim_min = 1 + ndim_max = 5 # max number of dims of the ndarray + size_max = 10 # max number of elements in each dim + length_max = 3 # max length of reps + rep_max = 10 # max number of tiling in each dim + for ndim in range(ndim_min, ndim_max+1): + shape = [] + for i in range(1, ndim+1): + shape.append(np.random.randint(1, size_max+1)) + shape = tuple(shape) a = np.random.randint(0, 100, shape) - a = np.asarray(a, dtype=np.int32) - if ndim == 0: - a = np.array([]) - b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype) + b = mx.nd.array(a, dtype=a.dtype) - reps_len = np.random.randint(0, length_max+1) + reps_len = np.random.randint(1, length_max+1) reps_tuple = () for i in range(1, reps_len): - reps_tuple += (np.random.randint(0, rep_max), ) + reps_tuple += (np.random.randint(1, rep_max), ) reps_array = np.asarray(reps_tuple) a_tiled = np.tile(a, reps_array) @@ -3521,14 +3520,6 @@ def test_empty_reps(): b_tiled = mx.nd.tile(b, ()).asnumpy() assert same(a_tiled, b_tiled) - def test_zero_reps(): - a = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32) - b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype) - reps = (2, 0, 4, 5) - a_tiled = np.tile(a, reps) - b_tiled = mx.nd.tile(b, reps).asnumpy() - assert same(a_tiled, b_tiled) - def test_tile_backward(): data = mx.sym.Variable('data') n1 = 2 @@ -3565,12 +3556,17 @@ def test_tile_numeric_gradient(): test = mx.sym.tile(data, reps=reps) check_numeric_gradient(test, [data_tmp], numeric_eps=1e-2, rtol=1e-2) + def test_invalid_reps(): + data = mx.nd.arange(16).reshape((4, 4)) + assert_exception(mx.nd.tile, MXNetError, data, (1, 2, -3)) + 
assert_exception(mx.nd.tile, MXNetError, data, (1, 0, 3)) + test_normal_case() test_empty_tensor() test_empty_reps() - test_zero_reps() test_tile_backward() test_tile_numeric_gradient() + test_invalid_reps() @with_seed() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services