eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] [WIP] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r182914225
##########
File path: tests/python/unittest/test_sparse_operator.py
##########
@@ -1209,6 +1209,27 @@ def check_cast_storage(shape, density, from_stype, to_stype, check_numeric_grad=
@with_seed()
def test_sparse_dot():
+    def test_infer_forward_stype(lhs_shape, rhs_shape, lhs_density, rhs_density, trans_a, trans_b):
+        all_stypes = ["default", "csr", "row_sparse"]
+        lhs_nd = rand_ndarray(lhs_shape, 'default', density=lhs_density)
+        rhs_nd = rand_ndarray(rhs_shape, 'default', density=rhs_density)
+        out_nd = mx.nd.dot(lhs_nd, rhs_nd, transpose_a=trans_a, transpose_b=trans_b)
+        out_np = out_nd.asnumpy()
+        for lhs_stype in all_stypes:
+            for rhs_stype in all_stypes:
+                for forward_stype in all_stypes:
+                    lhs = lhs_nd.tostype(lhs_stype)
+                    rhs = rhs_nd.tostype(rhs_stype)
+                    out = mx.nd.dot(lhs, rhs, forward_stype_hint=forward_stype,
+                                    transpose_a=trans_a, transpose_b=trans_b)
+                    assert_almost_equal(out.tostype('default').asnumpy(), out_np, rtol=1e-4, atol=1e-5)
+                    lhs_var = mx.symbol.Variable('lhs', stype=lhs_stype)
+                    rhs_var = mx.symbol.Variable('rhs', stype=rhs_stype)
+                    out = mx.symbol.sparse.dot(lhs_var, rhs_var,
+                                               forward_stype_hint=forward_stype,
+                                               transpose_a=trans_a, transpose_b=trans_b)
+                    location = {'lhs': lhs, 'rhs': rhs}
+                    check_symbolic_forward(out, location, [out_np], rtol=1e-3, atol=1e-4)
Review comment:
Should this also call check_symbolic_backward, to verify the gradients for each storage-type combination?
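A backward check needs expected gradients for lhs and rhs. For out = op(lhs)·op(rhs) with head gradient G, the gradient with respect to op(lhs) is G·op(rhs)^T and with respect to op(rhs) is op(lhs)^T·G, transposed back when transpose_a/transpose_b is set. The sketch below is a minimal, pure-Python illustration of those formulas, checked against central finite differences so it runs without MXNet. All helper names (dot, dot_backward, check_dot_backward, and so on) are hypothetical and not MXNet APIs; the assumption is that, in the real test, the same expected arrays computed from the dense lhs_nd/rhs_nd would be passed to check_symbolic_backward together with the existing location dict.

```python
import itertools
import random

# Hypothetical pure-Python helpers (not MXNet APIs). Matrices are lists of rows.

def transpose(m):
    """Return the transpose of a list-of-rows matrix."""
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    """Dense matrix product a @ b."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]

def dot(lhs, rhs, trans_a=False, trans_b=False):
    """Forward pass: op(lhs) @ op(rhs), where op transposes when requested."""
    a = transpose(lhs) if trans_a else lhs
    b = transpose(rhs) if trans_b else rhs
    return matmul(a, b)

def dot_backward(lhs, rhs, out_grad, trans_a=False, trans_b=False):
    """Expected (grad_lhs, grad_rhs) of sum(out_grad * dot(lhs, rhs, ...))."""
    a = transpose(lhs) if trans_a else lhs
    b = transpose(rhs) if trans_b else rhs
    grad_a = matmul(out_grad, transpose(b))  # d/d op(lhs) = G @ op(rhs)^T
    grad_b = matmul(transpose(a), out_grad)  # d/d op(rhs) = op(lhs)^T @ G
    # Undo the transpose so each gradient has the shape of its input.
    if trans_a:
        grad_a = transpose(grad_a)
    if trans_b:
        grad_b = transpose(grad_b)
    return grad_a, grad_b

def weighted_sum(lhs, rhs, out_grad, trans_a, trans_b):
    """Scalar sum(out_grad * dot(...)); dot_backward returns its gradients."""
    out = dot(lhs, rhs, trans_a, trans_b)
    return sum(g * o for g_row, o_row in zip(out_grad, out)
               for g, o in zip(g_row, o_row))

def numeric_grad(f, m, eps=1e-6):
    """Central differences of f() w.r.t. each entry of m (restored after use)."""
    grad = [[0.0] * len(m[0]) for _ in m]
    for i, row in enumerate(m):
        for j, orig in enumerate(row):
            row[j] = orig + eps
            plus = f()
            row[j] = orig - eps
            minus = f()
            row[j] = orig
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad

def max_abs_diff(x, y):
    return max(abs(p - q) for xr, yr in zip(x, y) for p, q in zip(xr, yr))

def rand_matrix(rng, rows, cols):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]

def check_dot_backward(m=2, k=3, n=4, seed=0):
    """Worst analytic-vs-numeric gradient error over all transpose combos."""
    rng = random.Random(seed)
    worst = 0.0
    for trans_a, trans_b in itertools.product([False, True], repeat=2):
        lhs = rand_matrix(rng, k, m) if trans_a else rand_matrix(rng, m, k)
        rhs = rand_matrix(rng, n, k) if trans_b else rand_matrix(rng, k, n)
        out_grad = rand_matrix(rng, m, n)
        grad_a, grad_b = dot_backward(lhs, rhs, out_grad, trans_a, trans_b)
        f = lambda: weighted_sum(lhs, rhs, out_grad, trans_a, trans_b)
        worst = max(worst,
                    max_abs_diff(grad_a, numeric_grad(f, lhs)),
                    max_abs_diff(grad_b, numeric_grad(f, rhs)))
    return worst

print("max gradient error:", check_dot_backward())
```

Because dot is linear in each input, the central differences are exact up to floating-point rounding, so a tight tolerance is enough here. With sparse inputs, the expected gradients would still be dense reference values, just as the forward check compares against the dense out_np.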