This is an automated email from the ASF dual-hosted git repository. zhreshold pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 73a9f1c Relaxing type requirements for slice_like op (#14097) 73a9f1c is described below commit 73a9f1ceb3d96b819472050514a59a5eae4baa92 Author: Przemyslaw Tredak <ptre...@gmail.com> AuthorDate: Thu Feb 14 11:15:10 2019 -0800 Relaxing type requirements for slice_like op (#14097) * Relaxing types for slice_like op * Added test * Fix typo in test * Fix lint --- src/operator/tensor/matrix_op.cc | 11 ++++++++++- tests/python/unittest/test_operator.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index e5d354b..3a244ac 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -661,7 +661,16 @@ Example:: return std::vector<std::string>{"data", "shape_like"}; }) .set_attr<nnvm::FInferShape>("FInferShape", SliceLikeShape) -.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>) +.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name; + std::vector<int> checked_in_attrs = { (*in_attrs)[0] }; + bool ret = !type_is_none((*in_attrs)[1]) && + ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs); + (*in_attrs)[0] = checked_in_attrs[0]; + return ret; + }) .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice_like"}) .set_attr<FCompute>("FCompute<cpu>", SliceLikeForward<cpu>) .add_argument("data", "NDArray-or-Symbol", "Source input") diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1f42215..fc003b2 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -2516,6 +2516,20 @@ def test_slice_like(): assert_allclose(xgrad1.asnumpy(), mx.nd.zeros_like(xgrad1).asnumpy()) @with_seed() +def test_slice_like_different_types(): + x = [[ 1., 2., 3., 4.], + [ 5., 6., 7., 8.], + [ 9., 10., 11., 12.]] + + y = [[ 0., 0., 0.], + [ 0., 0., 0.]] + + x = mx.nd.array(x) + y = mx.nd.array(y).astype('int32') + z = mx.nd.slice_like(x, y) + assert_allclose(z.asnumpy(), [[1,2,3],[5,6,7]]) + +@with_seed() def test_flip(): for ndim in range(1, 6): for t in range(5):