MarisaKirisame commented on a change in pull request #6767:
URL: https://github.com/apache/incubator-tvm/pull/6767#discussion_r512425339
##########
File path: python/tvm/relay/op/_tensor_grad.py
##########
@@ -665,3 +673,115 @@ def cross_entropy_with_logits_grad(orig, grad):
batch_size = take(shape, const(0, dtype="int32"), axis=0)
grad = grad / batch_size.astype(x.checked_type.dtype)
return [-grad * y, -grad * x]
+
+
+@register_gradient("take")
+def take_grad(orig, grad):
Review comment:
you can get by with defining a 'put' operator that puts a scalar at an
index of a tensor and leaves the other places unchanged. put and take satisfy
some classic algebraic properties, which I assume will be better for the
optimizer. It also allows other optimizations (e.g. put and reduce_sum): grad +
(put vala at idxa in 0_array) + (put valb at idxb in 0_array) will be collapsed
into a long chain of puts on grad, allowing COW to kick in so the take gradient
becomes a mutating update (instead of creating another tensor).
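A rough NumPy sketch of the idea, purely for illustration (there is no `put` operator in Relay; `put` and `take_grad` below are hypothetical names, and real Relay gradients are built from Relay expressions, not NumPy):

```python
import numpy as np

def put(base, index, value):
    # Hypothetical functional 'put': return a copy of `base` with flat
    # position `index` replaced by `value`; all other elements unchanged.
    # The "classic property" is then: put(t, i, v).flat[i] == v, i.e.
    # take after put at the same index recovers the stored value.
    out = base.copy()
    out.flat[index] = value
    return out

def take_grad(x_shape, indices, grad):
    # Gradient of a flat take expressed via put: start from a zero tensor
    # and accumulate each output gradient back at the index it was taken
    # from. This is the chain grad = put(... put(zeros, i0, g0) ...) that
    # an optimizer could collapse into in-place updates under COW.
    out = np.zeros(x_shape)
    for i, g in zip(indices, grad):
        out = put(out, i, out.flat[i] + g)  # repeated indices accumulate
    return out

# Example: x has shape (4,), take indices [1, 3, 1], incoming grad [10, 20, 30].
# Index 1 is hit twice, so its contributions (10 + 30) accumulate.
g = take_grad((4,), [1, 3, 1], np.array([10.0, 20.0, 30.0]))
print(g)  # [ 0. 40.  0. 20.]
```

Each `put` only touches one element, so a chain of puts on a freshly created zero tensor can legally be rewritten as mutation once the optimizer knows no one else aliases the intermediate.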
##########
File path: python/tvm/relay/op/_tensor_grad.py
##########
@@ -665,3 +673,115 @@ def cross_entropy_with_logits_grad(orig, grad):
batch_size = take(shape, const(0, dtype="int32"), axis=0)
grad = grad / batch_size.astype(x.checked_type.dtype)
return [-grad * y, -grad * x]
+
+
+@register_gradient("take")
+def take_grad(orig, grad):
Review comment:
@jroesch please look and comment as well.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]