eric-haibin-lin closed pull request #13599: fallback to dense version for grad(reshape), grad(expand_dims)
URL: https://github.com/apache/incubator-mxnet/pull/13599
This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for the sake of provenance:
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 9730d0096e5..7f69395d1c8 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -236,6 +236,20 @@ NNVM_REGISTER_OP(_backward_copy)
return std::vector<bool>{true};
});
+NNVM_REGISTER_OP(_backward_reshape)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
+.set_attr<nnvm::FInplaceIdentity>("FInplaceIdentity",
+  [](const NodeAttrs& attrs){
+    return std::vector<bool>{true};
+  });
+
MXNET_OPERATOR_REGISTER_UNARY(BlockGrad)
MXNET_ADD_SPARSE_OP_ALIAS(stop_gradient)
.add_alias("stop_gradient")
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index c28934e9465..14f2be02ab1 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -68,6 +68,10 @@ NNVM_REGISTER_OP(_copy)
.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);
NNVM_REGISTER_OP(_backward_copy)
+.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>)
+.set_attr<FComputeEx>("FComputeEx<gpu>", UnaryOp::IdentityComputeEx<gpu>);
+
+NNVM_REGISTER_OP(_backward_reshape)
.set_attr<FCompute>("FCompute<gpu>", UnaryOp::IdentityCompute<gpu>);
NNVM_REGISTER_OP(BlockGrad)
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index 2ffeabc11ae..db8efa45438 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -223,7 +223,7 @@ If the argument `reverse` is set to 1, then the special values are inferred from
.set_attr_parser(ParamParser<ReshapeParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ReshapeShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
@@ -415,7 +415,7 @@ will return a new array with shape ``(2,1,3,4)``.
[](const NodeAttrs& attrs){
return std::vector<bool>{true};
})
-.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_copy"})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_reshape"})
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::IdentityCompute<cpu>)
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(ExpandDimParam::__FIELDS__());
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 57808248b08..05175bb435f 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2306,6 +2306,48 @@ def check_sparse_quadratic_function(a, b, c, expected_stype):
check_sparse_quadratic_function(a, b, 0.0, 'csr')
check_sparse_quadratic_function(a, b, 1.0, 'default')
+def test_reshape_backward_fallback():
+    """
+       out
+       | \
+      w_x  x
+      /
+     w
+    in which x is a sparse tensor.
+    Due to sparse gradient optimization in sym.dot, grad(w_x) is sparse.
+    Though sym.reshape itself does not have a sparse version,
+    if we somehow make grad(w) sparse as well, e.g.
+     - by setting args_grad in symbol.bind, or
+     - by computing out_y = sym.dot(sparse_y, w), so that grad(w) is inferred as sparse,
+    then reshape backward (from w_x to w) needs to understand how to handle sparse inputs.
+    """
+    ctx = default_context()
+    w_shape = (12, 4)
+    w_x_shape = (1, 48)
+    x_nd = rand_ndarray((4, 1), 'csr')
+
+    w_nd = rand_ndarray(w_shape)
+
+    w_x_nd = w_nd.reshape(w_x_shape)
+    out_x_nd = mx.nd.dot(x_nd, w_x_nd)
+
+    w_x_backward_grad = mx.nd.dot(x_nd, out_x_nd, transpose_a=True).asnumpy()
+    expected_grad_nd = w_x_backward_grad.reshape(w_shape)
+
+    x = mx.sym.Variable('x', stype='csr')
+    w = mx.sym.Variable('w')
+
+    w_x = mx.sym.reshape(w, w_x_shape, name="w_x")
+    out = mx.sym.sparse.dot(x, w_x, name='out_x')
+
+    grad_w_nd = rand_ndarray(w_shape, 'row_sparse')
+    executor = out.bind(ctx=ctx, args={"x": x_nd, "w": w_nd},
+                        args_grad={"w": grad_w_nd})
+    executor.forward(is_train=True)
+    executor.backward(out_x_nd)
+
+    assert_almost_equal(grad_w_nd.asnumpy(), expected_grad_nd)
+
if __name__ == '__main__':
import nose
nose.runmodule()
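For context, the scenario the new test exercises can be reproduced outside the test suite. The sketch below is illustrative only; it assumes an MXNet 1.x build containing this patch, and the variable names and shapes are chosen for the example (mirroring the test above). A dense weight w is reshaped, multiplied against a CSR input through sparse.dot, and given a row_sparse gradient buffer, so the backward pass of the reshape receives a sparse head gradient.

import mxnet as mx

x = mx.sym.Variable('x', stype='csr')    # sparse input
w = mx.sym.Variable('w')                 # dense weight
w_x = mx.sym.reshape(w, shape=(1, 48))   # reshape whose backward sees a sparse gradient
out = mx.sym.sparse.dot(x, w_x)          # sparse dot makes grad(w_x) row_sparse

x_nd = mx.nd.random.uniform(shape=(4, 1)).tostype('csr')
w_nd = mx.nd.random.uniform(shape=(12, 4))
grad_w = mx.nd.zeros((12, 4)).tostype('row_sparse')   # request a sparse gradient buffer for w

exe = out.bind(ctx=mx.cpu(), args={'x': x_nd, 'w': w_nd}, args_grad={'w': grad_w})
exe.forward(is_train=True)
exe.backward(mx.nd.ones(exe.outputs[0].shape))
print(grad_w.asnumpy())   # gradient of w, computed through the dense fallback of _backward_reshape

Because _backward_reshape registers only FCompute (no FComputeEx or sparse storage inference), the executor's storage fallback densifies the sparse gradient before the identity copy, which is the "fallback to dense version" referred to in the title.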
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services