barry-jin commented on a change in pull request #20500:
URL: https://github.com/apache/incubator-mxnet/pull/20500#discussion_r689863276



##########
File path: python/mxnet/ndarray/ndarray.py
##########
@@ -2885,6 +2885,11 @@ def attach_grad(self, grad_req='write', stype=None):
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):

Review comment:
       Could you also add `drop_grad` to the MXNet NumPy array package?
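
       For reference, the requested usage might look something like the sketch below. `drop_grad` on `np.ndarray` does not exist yet (it is exactly what is being asked for here), so that call is hypothetical; the rest uses the existing NumPy-compatible API.

       ```python
       import mxnet as mx
       from mxnet import np, npx

       npx.set_np()  # enable NumPy-compatible semantics

       x = np.array([1.0, 2.0, 3.0, 4.0])
       x.attach_grad()
       with mx.autograd.record():
           y = (x * x).sum()
       y.backward()
       print(x.grad)    # gradient populated: [2. 4. 6. 8.]
       x.drop_grad()    # hypothetical: mirrors NDArray.drop_grad, frees the grad buffer
       print(x.grad)    # expected to be None after dropping
       ```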

##########
File path: tests/python/unittest/test_autograd.py
##########
@@ -519,3 +519,33 @@ def test_gradient():
     dx.backward()
     assert abs(x.grad.asscalar() - 2.71828175) < 1e-7
 
+def test_retain_grad_drop_grad():
+    x = nd.array([1,2,3,4])
+    x.attach_grad()
+    y = nd.array([5,6,7,8])
+    y.attach_grad()
+
+    with mx.autograd.record():
+        u = x * y
+        z = u * x
+
+    u.attach_grad()
+    z.attach_grad()
+    out_grad = nd.array([10, 10, 10, 10])
+    z.backward(out_grad, retain_graph=True)
+
+    assert (u.grad == out_grad * x).asnumpy().all()
+    assert (z.grad == out_grad).asnumpy().all()
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+    assert (y.grad == out_grad * x*x).asnumpy().all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
+    z.backward(out_grad)
+
+    assert u.grad is None
+    assert z.grad is None
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+    assert y.grad is None

Review comment:
       Could you also add some tests that exercise retaining gradients on Gluon blocks?
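
       For reference, one possible shape for such a test (just a sketch: the block, shapes, and values are illustrative, and the assertions assume the retain-grad semantics this PR introduces):

       ```python
       import mxnet as mx
       from mxnet import nd, gluon, autograd

       def test_retain_grad_gluon_block():
           net = gluon.nn.Dense(1, in_units=4)
           net.initialize()
           x = nd.array([[1.0, 2.0, 3.0, 4.0]])
           x.attach_grad()
           with autograd.record():
               h = net(x)             # intermediate output of the block
               z = (h * h).sum()
           h.attach_grad()            # retain the gradient of the intermediate output
           z.backward()
           assert h.grad is not None  # gradient of the block output was retained
           assert x.grad is not None  # input gradient still flows through the block
       ```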



