This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bf11516632 [Unity][Training] More Relax operators gradient supported (#14777)
bf11516632 is described below
commit bf11516632e2355a15b208fbf4aa75f1fda37959
Author: Chaofan Lin <[email protected]>
AuthorDate: Sun May 7 22:33:28 2023 +0800
[Unity][Training] More Relax operators gradient supported (#14777)
---
python/tvm/relax/op/_op_gradient.py | 125 ++++++++++++++++++++++++-
python/tvm/relax/op/binary.py | 5 +-
tests/python/relax/test_op_gradient_numeric.py | 27 ++++++
3 files changed, 151 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index b0e37a9418..36aed832e7 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -29,7 +29,8 @@ from ..expr import Call, Var, Expr, ShapeExpr
from ...tir import PrimExpr
from .base import register_gradient
-from .binary import less
+from .binary import less, greater_equal
+from .create import triu
from .datatype import astype
from .grad import (
no_grad,
@@ -54,7 +55,7 @@ from .manipulate import (
from .nn import conv2d_transpose, conv2d
from .search import where
from .statistical import sum, cumsum
-from .unary import cos, exp, log, sin, sqrt
+from .unary import cos, exp, log, sin, sqrt, sigmoid
# TODO(yixin, chaofan): handle symbolic shape for most of the gradients
@@ -231,7 +232,58 @@ def power_grad(
]
+@register_gradient("relax.maximum")
+def maximum_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of maximum.
+
+ Forward Form:
+ `z = relax.maximum(x, y)`
+
+ Backward:
+ Returns `[z_grad * (where(x < y, 0, 1)), z_grad * (where(x >= y, 0,
1))]`.
+ """
+ x = orig_call.args[0]
+ y = orig_call.args[1]
+ one = relax.const(1, _get_dtype(x))
+ zero = relax.const(0, _get_dtype(x))
+ return [
+ where(less(x, y), zero, one) * output_grad,
+ where(greater_equal(x, y), zero, one) * output_grad,
+ ]
+
+
+@register_gradient("relax.minimum")
+def minimum_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of minimum.
+
+ Forward Form:
+ `z = relax.minimum(x, y)`
+
+ Backward:
+ Returns `[z_grad * (where(x >= y, 0, 1)), z_grad * (where(x < y, 0,
1))]`.
+ """
+ x = orig_call.args[0]
+ y = orig_call.args[1]
+ one = relax.const(1, _get_dtype(x))
+ zero = relax.const(0, _get_dtype(x))
+ return [
+ where(greater_equal(x, y), zero, one) * output_grad,
+ where(less(x, y), zero, one) * output_grad,
+ ]
+
+
##################### Binary Comparison #####################
+
# For comparison operators, the gradients are no_grad
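For reference, the subgradient convention above routes the whole gradient to x
on ties: at x == y, where(x < y, 0, 1) gives 1 for x and where(x >= y, 0, 1)
gives 0 for y. A minimal numpy sketch of the same rule, useful as a standalone
sanity check (the helper name is illustrative, not part of this patch):

    import numpy as np

    def maximum_grad_np(x, y, z_grad):
        # Gradient flows to x where x >= y and to y where x < y,
        # mirroring where(x < y, 0, 1) and where(x >= y, 0, 1) above.
        return np.where(x < y, 0.0, 1.0) * z_grad, np.where(x >= y, 0.0, 1.0) * z_grad

    gx, gy = maximum_grad_np(np.array([1.0, 3.0, 2.0]), np.array([2.0, 2.0, 2.0]),
                             np.ones(3))
    # gx == [0., 1., 1.], gy == [1., 0., 0.]; the tie at index 2 goes to x.
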
@@ -296,7 +348,8 @@ def not_equal_grad(
##################### Create #####################
-# For create operators, the gradients are no_grad.
+
+# For zeros/ones/full operators, the gradients are no_grad.
@register_gradient("relax.zeros_like")
@@ -359,6 +412,28 @@ def full_grad(
return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
+# Gradients of other create operators
+
+
+@register_gradient("relax.triu")
+def triu_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of triu.
+
+ Forward Form:
+ `y = relax.triu(x, k)`
+
+ Backward:
+ Returns `[triu(y_grad, k)]`.
+ """
+ k = orig_call.attrs.k
+ return [triu(output_grad, k)]
+
+
##################### Unary #####################
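The triu gradient above relies on triu being a pure 0/1 mask: entries below the
k-th diagonal are zeroed in the forward pass and therefore receive no gradient,
while the rest pass through unchanged. In numpy terms (a sketch, not part of
the patch):

    import numpy as np

    y_grad = np.arange(9.0).reshape(3, 3)  # incoming output gradient
    k = 0
    # x_grad = triu(y_grad, k): the masked-out lower triangle gets zero.
    x_grad = np.triu(y_grad, k)
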
@@ -770,6 +845,29 @@ def cumsum_grad(
return [grad]
+@register_gradient("relax.broadcast_to")
+def broadcast_to_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of broadcast_to.
+
+ Forward Form:
+ `y = relax.broadcast_to(x, new_shape)`
+
+ Backward:
+ Returns `[collapse_sum_to(y_grad, x.shape), no_grad]`.
+
+ The second parameter, the target ShapeExpr, is not differentiable.
+ """
+ return [
+ collapse_sum_to(output_grad, _get_shape(orig_call.args[0])),
+ no_grad(orig_call.args[1]),
+ ]
+
+
##################### Index #####################
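The broadcast_to rule above uses the standard duality: broadcasting replicates
x along new or size-1 axes, so its adjoint sums the output gradient back over
exactly those axes, which is what collapse_sum_to does. A numpy equivalent for
the shape used in the new test further below, assuming only the leading axis
is broadcast:

    import numpy as np

    x = np.random.rand(3, 4)
    y_grad = np.ones((2, 3, 4))   # gradient w.r.t. broadcast_to(x, (2, 3, 4))
    x_grad = y_grad.sum(axis=0)   # collapse over the broadcast axis
    assert x_grad.shape == x.shape
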
@@ -938,6 +1036,27 @@ def relu_grad(
return [where(less(x, zero), zero, one) * output_grad]
+@register_gradient("relax.nn.silu")
+def silu_grad(
+ orig_var: Var,
+ orig_call: Call,
+ output_grad: Var,
+ ctx: BlockBuilder,
+) -> List[Expr]:
+ """Gradient of silu.
+
+ Forward Form:
+ `y = relax.silu(x)`
+
+ Backward:
+ Returns `[y_grad * (sigmoid(x) + y * (1 - sigmoid(x)))]`.
+ """
+ x = orig_call.args[0]
+ sig = sigmoid(x)
+ one = relax.const(1, _get_dtype(x))
+ return [output_grad * (sig + orig_var * (one - sig))]
+
+
@register_gradient("relax.nn.softmax")
def softmax_grad(
orig_var: Var,
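The silu rule reuses the forward output: with y = x * sigmoid(x), the
derivative sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) simplifies to
sigmoid(x) + y * (1 - sigmoid(x)), so passing orig_var avoids recomputing the
product. A quick central-difference check of that formula in plain numpy
(sketch only, not part of the patch):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    x = np.linspace(-3.0, 3.0, 7)
    y = x * sigmoid(x)                              # forward silu
    analytic = sigmoid(x) + y * (1.0 - sigmoid(x))
    eps = 1e-5
    numeric = ((x + eps) * sigmoid(x + eps)
               - (x - eps) * sigmoid(x - eps)) / (2 * eps)
    assert np.allclose(analytic, numeric, atol=1e-6)
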
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 09a0c30f19..e3664fd8d8 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -250,11 +250,9 @@ def not_equal(x1: Expr, x2: Expr) -> Expr:
return _ffi_api.not_equal(x1, x2) # type: ignore
-###################### Comparison operators ######################
-
-
def maximum(x1: Expr, x2: Expr) -> Expr:
"""Element-wise maximum
+
Parameters
----------
x1 : relax.Expr
@@ -272,6 +270,7 @@ def maximum(x1: Expr, x2: Expr) -> Expr:
def minimum(x1: Expr, x2: Expr) -> Expr:
"""Element-wise minimum
+
Parameters
----------
x1 : relax.Expr
diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py
index cf2ff777d2..49b7daf96b 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -233,6 +233,8 @@ def test_unary(target, dev, unary_op_func, can_be_neg):
(relax.op.multiply,),
(relax.op.divide,),
(relax.op.power,),
+ (relax.op.maximum,),
+ (relax.op.minimum,),
)
@@ -299,6 +301,12 @@ def test_ones_zeros(target, dev, create_op_func):
)
[email protected]_targets("llvm")
+def test_triu(target, dev):
+ data_numpy = np.random.uniform(-1, 1, (3, 3)).astype(np.float32)
+ relax_check_gradients(relax.op.triu, [data_numpy], target, dev, k=0)
+
+
##################### Statistical #####################
@@ -483,6 +491,19 @@ def test_expand_dims_list(target, dev):
    relax_check_gradients(relax.op.expand_dims, [data_numpy], target, dev, axis=(0, 2, 3))
[email protected]_targets("llvm")
+def test_broadcast_to(target, dev):
+ data_numpy = np.random.randint(1, 16, (3, 4)).astype(np.float32)
+ relax_check_gradients(
+ relax.op.broadcast_to,
+ [data_numpy],
+ target,
+ dev,
+ shape=(2, 3, 4),
+ ignore_grads=[1],
+ )
+
+
##################### Index #####################
@@ -594,6 +615,12 @@ def test_relu(target, dev):
relax_check_gradients(relax.op.nn.relu, [data1_numpy], target, dev)
[email protected]_targets("llvm")
+def test_silu(target, dev):
+ data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
+ relax_check_gradients(relax.op.nn.silu, [data1_numpy], target, dev)
+
+
@tvm.testing.parametrize_targets("llvm")
def test_softmax(target, dev):
data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
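All of the new tests go through relax_check_gradients, which compares the
registered analytic gradients against numerical estimates. Conceptually the
numerical side is a central-difference loop like the following (a sketch of
the idea only, not the helper's actual implementation):

    import numpy as np

    def numeric_grad(f, x, eps=1e-4):
        # Central-difference estimate of d f(x) / dx for scalar-valued f.
        grad = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            orig = x[idx]
            x[idx] = orig + eps
            f_plus = f(x)
            x[idx] = orig - eps
            f_minus = f(x)
            x[idx] = orig                     # restore the perturbed entry
            grad[idx] = (f_plus - f_minus) / (2 * eps)
        return grad

    # e.g. numeric_grad(lambda a: np.maximum(a, 2.0).sum(), np.array([1.0, 3.0]))
    # recovers the [0., 1.] mask that maximum_grad produces analytically.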