This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bf11516632 [Unity][Training] More Relax operators gradient supported (#14777)
bf11516632 is described below

commit bf11516632e2355a15b208fbf4aa75f1fda37959
Author: Chaofan Lin <[email protected]>
AuthorDate: Sun May 7 22:33:28 2023 +0800

    [Unity][Training] More Relax operators gradient supported (#14777)
---
 python/tvm/relax/op/_op_gradient.py            | 125 ++++++++++++++++++++++++-
 python/tvm/relax/op/binary.py                  |   5 +-
 tests/python/relax/test_op_gradient_numeric.py |  27 ++++++
 3 files changed, 151 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index b0e37a9418..36aed832e7 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -29,7 +29,8 @@ from ..expr import Call, Var, Expr, ShapeExpr
 from ...tir import PrimExpr
 
 from .base import register_gradient
-from .binary import less
+from .binary import less, greater_equal
+from .create import triu
 from .datatype import astype
 from .grad import (
     no_grad,
@@ -54,7 +55,7 @@ from .manipulate import (
 from .nn import conv2d_transpose, conv2d
 from .search import where
 from .statistical import sum, cumsum
-from .unary import cos, exp, log, sin, sqrt
+from .unary import cos, exp, log, sin, sqrt, sigmoid
 
 
 # TODO(yixin, chaofan): handle symbolic shape for most of the gradients
@@ -231,7 +232,58 @@ def power_grad(
     ]
 
 
+@register_gradient("relax.maximum")
+def maximum_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of maximum.
+
+    Forward Form:
+        `z = relax.maximum(x, y)`
+
+    Backward:
+        Returns `[z_grad * (where(x < y, 0, 1)), z_grad * (where(x >= y, 0, 1))]`.
+    """
+    x = orig_call.args[0]
+    y = orig_call.args[1]
+    one = relax.const(1, _get_dtype(x))
+    zero = relax.const(0, _get_dtype(x))
+    return [
+        where(less(x, y), zero, one) * output_grad,
+        where(greater_equal(x, y), zero, one) * output_grad,
+    ]
+
+
+@register_gradient("relax.minimum")
+def minimum_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of minimum.
+
+    Forward Form:
+        `z = relax.minimum(x, y)`
+
+    Backward:
+        Returns `[z_grad * (where(x >= y, 0, 1)), z_grad * (where(x < y, 0, 1))]`.
+    """
+    x = orig_call.args[0]
+    y = orig_call.args[1]
+    one = relax.const(1, _get_dtype(x))
+    zero = relax.const(0, _get_dtype(x))
+    return [
+        where(greater_equal(x, y), zero, one) * output_grad,
+        where(less(x, y), zero, one) * output_grad,
+    ]
+
+
 ##################### Binary Comparison #####################
+
 # For comparison operators, the gradients are no_grad
 
 
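As a standalone illustration (not part of this patch), the where-based rule in the two docstrings above can be checked directly in NumPy. Ties (x == y) send the full gradient to x for maximum and to y for minimum, so the two masks are exactly complementary and never double-count:

    import numpy as np

    x = np.array([1.0, 3.0, 2.0])
    y = np.array([2.0, 1.0, 2.0])
    z_grad = np.array([0.5, 0.5, 0.5])  # incoming output gradient

    # Gradient of z = maximum(x, y): ones where x >= y, zeros elsewhere.
    x_grad = np.where(x < y, 0.0, 1.0) * z_grad   # [0. , 0.5, 0.5]
    y_grad = np.where(x >= y, 0.0, 1.0) * z_grad  # [0.5, 0. , 0. ]

    # Complementary masks: the incoming gradient is fully accounted for.
    assert np.allclose(x_grad + y_grad, z_grad)
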
@@ -296,7 +348,8 @@ def not_equal_grad(
 
 
 ##################### Create #####################
-# For create operators, the gradients are no_grad.
+
+# For zeros/ones/full operators, the gradients are no_grad.
 
 
 @register_gradient("relax.zeros_like")
@@ -359,6 +412,28 @@ def full_grad(
     return [no_grad(orig_call.args[0]), no_grad(orig_call.args[1])]
 
 
+# Gradients for other create operators
+
+
+@register_gradient("relax.triu")
+def triu_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of triu.
+
+    Forward Form:
+        `y = relax.triu(x, k)`
+
+    Backward:
+        Returns `[triu(y_grad, k)]`.
+    """
+    k = orig_call.attrs.k
+    return [triu(output_grad, k)]
+
+
 ##################### Unary #####################
 
 
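As an aside (not part of this patch), the triu gradient has a direct NumPy analogue: entries that triu zeroes in the forward pass contribute nothing to the output, so they receive zero gradient, while the kept upper-triangular entries pass the output gradient through unchanged:

    import numpy as np

    y_grad = np.ones((3, 3))

    # Backward of y = triu(x, k): mask y_grad with the same upper triangle.
    x_grad = np.triu(y_grad, k=0)
    # [[1. 1. 1.]
    #  [0. 1. 1.]
    #  [0. 0. 1.]]
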
@@ -770,6 +845,29 @@ def cumsum_grad(
     return [grad]
 
 
+@register_gradient("relax.broadcast_to")
+def broadcast_to_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of broadcast_to.
+
+    Forward Form:
+        `y = relax.broadcast_to(x, new_shape)`
+
+    Backward:
+        Returns `[collapse_sum_to(y_grad, x.shape), no_grad]`.
+
+        The second parameter, the target ShapeExpr, is not differentiable.
+    """
+    return [
+        collapse_sum_to(output_grad, _get_shape(orig_call.args[0])),
+        no_grad(orig_call.args[1]),
+    ]
+
+
 ##################### Index #####################
 
 
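For intuition (again, not part of the patch): collapse_sum_to is the adjoint of broadcast_to. Each broadcast copy of an input element contributes its own gradient, so the backward pass sums the output gradient over the broadcast axes back down to the input shape. A NumPy equivalent for a (3, 4) input broadcast to (2, 3, 4):

    import numpy as np

    x = np.zeros((3, 4), dtype=np.float32)
    y_grad = np.ones((2, 3, 4), dtype=np.float32)

    # broadcast_to copied x along a new leading axis, so the gradient
    # sums over that axis: shape (2, 3, 4) collapses back to (3, 4).
    x_grad = y_grad.sum(axis=0)
    assert x_grad.shape == x.shape and np.all(x_grad == 2.0)
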
@@ -938,6 +1036,27 @@ def relu_grad(
     return [where(less(x, zero), zero, one) * output_grad]
 
 
+@register_gradient("relax.nn.silu")
+def silu_grad(
+    orig_var: Var,
+    orig_call: Call,
+    output_grad: Var,
+    ctx: BlockBuilder,
+) -> List[Expr]:
+    """Gradient of silu.
+
+    Forward Form:
+        `y = relax.silu(x)`
+
+    Backward:
+        Returns `[y_grad * (sigmoid(x) + y * (1 - sigmoid(x)))]`.
+    """
+    x = orig_call.args[0]
+    sig = sigmoid(x)
+    one = relax.const(1, _get_dtype(x))
+    return [output_grad * (sig + orig_var * (one - sig))]
+
+
 @register_gradient("relax.nn.softmax")
 def softmax_grad(
     orig_var: Var,
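The silu gradient above follows from the product rule: with y = x * sigmoid(x) and sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)), we get dy/dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) + y * (1 - sigmoid(x)), which is why the implementation can reuse the forward result orig_var rather than recomputing x * sigmoid(x). A standalone NumPy check against central finite differences (not part of the patch):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def silu(x):
        return x * sigmoid(x)

    x = np.linspace(-3.0, 3.0, 7)
    analytic = sigmoid(x) + silu(x) * (1.0 - sigmoid(x))

    # Central differences agree with the analytic formula.
    eps = 1e-5
    numeric = (silu(x + eps) - silu(x - eps)) / (2.0 * eps)
    assert np.allclose(analytic, numeric, atol=1e-6)
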
diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py
index 09a0c30f19..e3664fd8d8 100644
--- a/python/tvm/relax/op/binary.py
+++ b/python/tvm/relax/op/binary.py
@@ -250,11 +250,9 @@ def not_equal(x1: Expr, x2: Expr) -> Expr:
     return _ffi_api.not_equal(x1, x2)  # type: ignore
 
 
-###################### Comparison operators ######################
-
-
 def maximum(x1: Expr, x2: Expr) -> Expr:
     """Element-wise maximum
+
     Parameters
     ----------
     x1 : relax.Expr
@@ -272,6 +270,7 @@ def maximum(x1: Expr, x2: Expr) -> Expr:
 
 def minimum(x1: Expr, x2: Expr) -> Expr:
     """Element-wise minimum
+
     Parameters
     ----------
     x1 : relax.Expr
diff --git a/tests/python/relax/test_op_gradient_numeric.py b/tests/python/relax/test_op_gradient_numeric.py
index cf2ff777d2..49b7daf96b 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -233,6 +233,8 @@ def test_unary(target, dev, unary_op_func, can_be_neg):
     (relax.op.multiply,),
     (relax.op.divide,),
     (relax.op.power,),
+    (relax.op.maximum,),
+    (relax.op.minimum,),
 )
 
 
@@ -299,6 +301,12 @@ def test_ones_zeros(target, dev, create_op_func):
     )
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_triu(target, dev):
+    data_numpy = np.random.uniform(-1, 1, (3, 3)).astype(np.float32)
+    relax_check_gradients(relax.op.triu, [data_numpy], target, dev, k=0)
+
+
 ##################### Statistical #####################
 
 
@@ -483,6 +491,19 @@ def test_expand_dims_list(target, dev):
     relax_check_gradients(relax.op.expand_dims, [data_numpy], target, dev, axis=(0, 2, 3))
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_broadcast_to(target, dev):
+    data_numpy = np.random.randint(1, 16, (3, 4)).astype(np.float32)
+    relax_check_gradients(
+        relax.op.broadcast_to,
+        [data_numpy],
+        target,
+        dev,
+        shape=(2, 3, 4),
+        ignore_grads=[1],
+    )
+
+
 ##################### Index #####################
 
 
@@ -594,6 +615,12 @@ def test_relu(target, dev):
     relax_check_gradients(relax.op.nn.relu, [data1_numpy], target, dev)
 
 
+@tvm.testing.parametrize_targets("llvm")
+def test_silu(target, dev):
+    data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)
+    relax_check_gradients(relax.op.nn.silu, [data1_numpy], target, dev)
+
+
 @tvm.testing.parametrize_targets("llvm")
 def test_softmax(target, dev):
     data1_numpy = np.random.randint(0, 16, (3, 3)).astype(np.float32)

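The tests above call the relax_check_gradients helper defined elsewhere in this test module; its internals are not shown in this diff, but numeric gradient checks of this kind typically compare the registered gradient against central finite differences of the forward computation. A minimal self-contained sketch of the idea (an illustration only, not the actual helper):

    import numpy as np

    def numeric_grad(f, x, eps=1e-4):
        # Central-difference gradient of a scalar-valued f at x.
        grad = np.zeros_like(x)
        for idx in np.ndindex(x.shape):
            orig = x[idx]
            x[idx] = orig + eps
            hi = f(x)
            x[idx] = orig - eps
            lo = f(x)
            x[idx] = orig
            grad[idx] = (hi - lo) / (2.0 * eps)
        return grad

    # Check triu_grad's rule: d(sum(triu(x)))/dx is triu(ones).
    x = np.random.uniform(-1.0, 1.0, (3, 3))
    numeric = numeric_grad(lambda v: np.triu(v, k=0).sum(), x)
    analytic = np.triu(np.ones_like(x), k=0)
    assert np.allclose(numeric, analytic, atol=1e-3)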