kshitij12345 commented on issue #17325: Fix Flaky Test Higher Order Grad

URL: https://github.com/apache/incubator-mxnet/pull/17325#issuecomment-575979165

@sxjscience @apeforest @larroy While trying to track down the causes of the failures, I found that there is an issue with the way we compute the second-order gradients.

Sample: https://github.com/apache/incubator-mxnet/blob/49df604ac3bb156e446671403c86a3535af9615d/src/operator/tensor/elemwise_unary_op_trig.cc#L201-L207

While computing `x_grad`, elements of `dydx` can be `0`, which leads to `NaN` at that position in our computation instead of `0` (a minimal standalone sketch of this `0/0` behaviour is at the end of this comment). Most of the failures I have checked in CI appear to be of this kind. Example: #17362, where we expect `0` but end up with `NaN`:

```
Items are not equal:
Error nan exceeds tolerance rtol=1.000000e-05, atol=1.000000e-20 (mismatch 0.154321%).
Location of maximum error: (1, 6, 2, 1), a=0.00000000, b=nan
 ACTUAL: array([[[[-0.03453423,  0.1168007 , -0.00583075],
         [ 0.02170432,  0.37017354,  0.09864384],
         [-0.02544087, -0.06709251, -0.34824234]],...
 DESIRED: array([[[[-0.03453423,  0.1168007 , -0.00583075],
         [ 0.02170432,  0.37017348,  0.09864385],
         [-0.02544087, -0.0670925 , -0.34824237]],...
-------------------- >> begin captured stdout << ---------------------
*** Maximum errors for vector of size 648:  rtol=1e-05, atol=1e-20
  1: Error nan  Location of error: (1, 6, 2, 1), a=0.00000000, b=nan
--------------------- >> end captured stdout << ----------------------
-------------------- >> begin captured logging << --------------------
common: INFO: Setting test np/mx/python random seeds, use MXNET_TEST_SEED=402611671 to reproduce.
--------------------- >> end captured logging << ---------------------
```

---

**Mitigation**: The reason we divide by `dydx` is to recover the first-order gradient from the previous computation instead of computing it again. One solution might be to replace

```
auto x_grad = op.div(dydx_mul_grad_x, dydx);
```

with an actual computation of the first-order gradient:

```
auto x_grad = nnvm::NodeEntry{mxnet::op::MakeNode(
    "_backward_arcsin", dependent_node->attrs.name + "_mul_scalar",
    {x}, nullptr, &dependent_node)};
```

Will try this and report back on how it goes. Let me know if you have any other ideas.
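For reference, here is a minimal standalone sketch (plain C++ with hypothetical scalar values; the real code operates on tensors) of why the division produces `NaN` rather than `0`: at any position where `dydx` is `0`, `dydx_mul_grad_x` (being `dydx * grad_x`) is also `0`, and IEEE-754 defines `0.0 / 0.0` as `NaN`.

```
#include <cmath>
#include <cstdio>

int main() {
  // Hypothetical element values at a position where the first-order
  // gradient is exactly zero.
  float dydx = 0.0f;                      // first-order gradient
  float grad_x = 0.5f;                    // incoming head gradient
  float dydx_mul_grad_x = dydx * grad_x;  // == 0.0f

  // Mirrors op.div(dydx_mul_grad_x, dydx): 0/0 is NaN under IEEE-754,
  // while the test expects 0 at this position.
  float x_grad = dydx_mul_grad_x / dydx;
  std::printf("x_grad = %f, isnan = %d\n", x_grad, (int)std::isnan(x_grad));
  return 0;
}
```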
