larroy commented on a change in pull request #15120: [bug] fix higher grad log 
URL: https://github.com/apache/incubator-mxnet/pull/15120#discussion_r290873336
 
 

 ##########
 File path: tests/python/unittest/test_higher_order_grad.py
 ##########
 @@ -27,52 +27,79 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log, grad_grad_op)
+        check_second_order_unary(array, log, grad_op, grad_grad_op)
 
 
 @with_seed()
 def test_log2():
     def log2(x):
         return nd.log2(x)
 
+    def grad_op(x):
+        return 1/(x * math.log(2))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(2))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log2, grad_grad_op)
+        check_second_order_unary(array, log2, grad_op, grad_grad_op)
 
 
 @with_seed()
 def test_log10():
     def log10(x):
         return nd.log10(x)
 
+    def grad_op(x):
+        return 1/(x * math.log(10))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(10))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log10, grad_grad_op)
+        check_second_order_unary(array, log10, grad_op, grad_grad_op)
 
 
-def check_second_order_unary(x, op, grad_grad_op):
+def check_second_order_unary(x, op, grad_op, grad_grad_op):
     x = nd.array(x)
-    expect_grad_grad = grad_grad_op(x)
+    grad_x = grad_op(x)
+    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
+
+    # Manual head_grads.
+    head_grads = nd.random.normal(shape=x.shape)
+    head_grad_grads = nd.random.normal(shape=x.shape)
+    head_grads.attach_grad()
+
+    # Perform compute.
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
-    y_grad.backward()
-    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+        y_grad = autograd.grad(y, x, head_grads=head_grads,
+                               create_graph=True, retain_graph=True)[0]
+
+    y_grad.backward(head_grad_grads)
+
+    # Compute expected values.
+    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
+        head_grads.asnumpy()
+    expected_heads_grad = grad_x.asnumpy()
+
+    # Validate the gradients.
+    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
+    assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
 
 Review comment:
   y_grad.backward(head_grad_grads) indicates that head_grad_grads are the head gradients; backward will update all the input variables which have a gradient attached. In this case head_grad_grads is not an input to the graph, so the problem that its grad doesn't get updated is expected:
   
   
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/ndarray/ndarray.py#L2188
   
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py#L270
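   
   To illustrate, here is a minimal sketch (not from the PR; the array names are illustrative only) of the behaviour described above: backward(out_grad) treats its argument purely as head gradients, so gradients are written only to arrays that were recorded as graph inputs via attach_grad().
   
   ```python
   from mxnet import nd, autograd

   x = nd.array([1.0, 2.0, 3.0])
   x.attach_grad()

   head = nd.ones_like(x)
   head.attach_grad()  # attaching a grad does not make `head` a graph input

   with autograd.record():
       y = nd.log(x)

   # `head` is used only as the head gradient of y, not as a graph input.
   y.backward(head)

   print(x.grad)     # head * dy/dx = head / x -> updated, x was recorded
   print(head.grad)  # stays all zeros: `head` never entered the recorded graph
   ```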
