This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0c447d6f9c [Unity][Transform] High-level reverse-mode automatic 
differentiation pass (#14542)
0c447d6f9c is described below

commit 0c447d6f9c85cb1f2b3f246a1c58cb6a349bf600
Author: Chaofan Lin <[email protected]>
AuthorDate: Sun Apr 9 20:13:21 2023 +0800

    [Unity][Transform] High-level reverse-mode automatic differentiation pass 
(#14542)
    
    ### Introduction
    
    This PR introduces the high-level reverse-mode automatic differentiation 
pass `Gradient` for Relax. It is the core component for training or 
fine-tuning models in Relax IR.
    
    Before upstreaming, this work was actively iterated on and maintained in 
several forks, such as [mlc](https://github.com/mlc-ai/relax) and 
[relax-training](https://github.com/ACMClass-TVM-20/relax-training). It has now 
reached a relatively stable version, so it is time to upstream this 
important work to the unity branch.
    
    The Python side API:
    - `Gradient(func_name: str, require_grads: Optional[Union[Var, List[Var]]] 
= None, target_index: int = 0) -> tvm.ir.transform.Pass`
    
    It takes the given function in the IRModule and adds a new function that 
calculates the gradients of the function's output with respect to its parameters.
    
    ### Examples
    
    ```
    @I.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
                # use R.sum to reduce the tensor to a scalar
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
                R.output(lv2)
            return lv2
    
    After = relax.transform.Gradient("main")(Module)
    ```
    
    Then the transformed module `After` will be
    
    ```
    @I.ir_module
    class After:
        @R.function
        def main(
            x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
        ) -> R.Tensor((), dtype="float32"):
            with R.dataflow():
                lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
                R.output(lv2)
            return lv2
    
        @R.function
        def main_adjoint(
            x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
        ) -> R.Tuple(
            R.Tensor((), dtype="float32"),
            R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")),
        ):
            with R.dataflow():
                # original bindings
                lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
                # bindings w.r.t. intermediate variables
                lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), 
dtype="float32")
                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                    lv2_adjoint, (3, 3)
                )
                # bindings w.r.t. parameters
                x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
                y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
                R.output(lv2, x_adjoint, y_adjoint)
            # return value: (orig_return_values, tuple(adjoints))
            return (lv2, (x_adjoint, y_adjoint))
    ```
    
    Here we specify the target function `main` by its name.
    We leave `require_grads` at its default value (`None`), so the pass computes 
the adjoints of all inputs (`x_adjoint`, `y_adjoint`) and returns them.
    We leave `target_index` at its default value (`0`), so the pass takes the only 
return value `lv2` as the target of AD (the scalar from which differentiation 
starts and adjoints are propagated).
    
    
    
    Co-authored-by: Yixin Dong <[email protected]>
---
 include/tvm/relax/transform.h                      |   26 +
 python/tvm/relax/transform/transform.py            |  170 +++
 src/relax/transform/gradient.cc                    |  469 ++++++++
 src/relax/transform/utils.cc                       |   13 +
 src/relax/transform/utils.h                        |   16 +
 tests/python/relax/test_transform_gradient.py      | 1164 ++++++++++++++++++++
 .../relax/test_transform_gradient_numeric.py       |  192 ++++
 7 files changed, 2050 insertions(+)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5615a0951f..fec2ef0a04 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -348,6 +348,32 @@ class PatternCheckContext : public ObjectRef {
                                             PatternCheckContextNode);
 };
 
+/*!
+ * \brief Reverse-mode automatic differentiation.
+ *
+ * This pass will differentiate one function in the IRModule. Currently, the input 
function must have
+ * exactly one dataflow block.
+ *
+ * For a given function specified by `func_name`, it generates a new function 
with the name
+ * `func_name + "_adjoint"`. The new function computes the gradient of the 
**differentiation
+ * target** with respect to the arguments specified by `require_grads` of the 
original function.
+ *
+ * If the function has only one return value, that return value is the 
target. If the
+ * function has more than one return value, the target is the 
target_index-th
+ * return value. The target must be a scalar (0-dim tensor).
+ *
+ * \param func_name The name of the specified function.
+ * \param require_grads The relax variables whose adjoints are needed. Must be 
parameters of the
+ * given function and must not contain duplicates. If it is not specified, 
adjoints of all parameters
+ * will be computed.
+ * \param target_index If the specified function has more than one return 
value, specify the index
+ * of the return value as the target. If it is not specified, the first return 
value will be the
+ * target.
+ * \return The Pass.
+ */
+TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = 
NullOpt,
+                      int target_index = 0);
+
 /*!
  * \brief Apply pattern matching to each function in the given module, and 
group matched
  * expressions into a new function. The end result is similar to FuseOps, but 
fusion is driven
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 83f2f1bba8..a53c45b655 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -46,6 +46,176 @@ class DataflowBlockPass(tvm.ir.transform.Pass):
     """A pass that works on each tvm.relax.DataflowBlock in a module."""
 
 
+def Gradient(
+    func_name: str, require_grads: Optional[Union[Var, List[Var]]] = None, 
target_index: int = 0
+) -> tvm.ir.transform.Pass:
+    """Reverse-mode automatic differentiation.
+
+    This pass will differentiate one function in the IRModule. Currently, the input 
function must have
+    exactly one dataflow block.
+
+    For a given function specified by `func_name`, it generates a new function 
with the name
+    `func_name + "_adjoint"`. The new function computes the gradient of the 
**differentiation
+    target** with respect to the arguments specified by `require_grads` of the 
original function.
+
+    If the function has only one return value, that return value is the 
target. If the
+    function has more than one return value, the target is the 
target_index-th
+    return value. The target must be a scalar (0-dim tensor).
+
+    The new function will be like:
+
+    .. code-block:: python
+        @R.function
+        def main_adjoint(original_parameters):
+            with R.dataflow():
+                # the bindings of the original function
+                ...
+                # calculating the gradients
+                ...
+                R.output(original_outputs, grad_1, grad_2, ...)
+            return (original_return_value, (grad_1, grad_2, ...))
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the specific function.
+
+    require_grads : Optional[Union[relax.Var, List[relax.Var]]]
+        The relax variables whose adjoints are needed. Must be parameters of 
the given function and
+        must not contain duplicates. If it is not specified, adjoints of all 
parameters will be
+        computed.
+
+    target_index : int
+        If the specified function has more than one return value, specify the 
index of the return
+        value as the target. If it is not specified, the first return value 
will be the target.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+        The Pass.
+
+    Examples
+    --------
+    The following code shows how to use this pass:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tensor((), dtype="float32"):
+                with R.dataflow():
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    # use R.sum to reduce the tensor to a scalar
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                    R.output(lv2)
+                return lv2
+
+        After = relax.transform.Gradient("main")(Module)
+
+    The module after the Gradient pass will be:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class After:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tensor((), dtype="float32"):
+                with R.dataflow():
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                    R.output(lv2)
+                return lv2
+
+            @R.function
+            def main_adjoint(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tuple(
+                R.Tensor((), dtype="float32"),
+                R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")),
+            ):
+                with R.dataflow():
+                    # original bindings
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                    # bindings w.r.t. intermediate variables
+                    lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), 
dtype="float32")
+                    lv1_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(
+                        lv2_adjoint, (3, 3)
+                    )
+                    # bindings w.r.t. parameters
+                    x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
+                    y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
+                    R.output(lv2, x_adjoint, y_adjoint)
+                # return value: (orig_return_values, tuple(adjoints))
+                return (lv2, (x_adjoint, y_adjoint))
+
+    The second example returns multiple values and specifies the target 
with `target_index`:
+
+    .. code-block:: python
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), 
dtype="float32")):
+                with R.dataflow():
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, 
keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, 
keepdims=False)
+                    R.output(lv1, lv2)
+                return (lv1, lv2)
+
+        After = relax.transform.Gradient("main", target_index=1)(Module)
+
+    The module after the Gradient pass will be:
+
+    .. code-block:: python
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), 
dtype="float32")):
+                with R.dataflow():
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, 
keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, 
keepdims=False)
+                    R.output(lv1, lv2)
+                return (lv1, lv2)
+
+            @R.function
+            def main_adjoint(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32")
+            ) -> R.Tuple(
+                R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), 
dtype="float32")),
+                R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32")),
+            ):
+                with R.dataflow():
+                    # original bindings
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, 
keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, 
keepdims=False)
+                    # bindings w.r.t. intermediate variables
+                    # gradients of intermediate variables that are not related 
to the target will not
+                    # be calculated
+                    lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), 
dtype="float32")
+                    # bindings w.r.t. parameters
+                    x_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros((3, 
3), dtype="float32")
+                    y_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(
+                        lv2_adjoint, (3, 3)
+                    )
+                    R.output(lv1, lv2, x_adjoint, y_adjoint)
+                # return value: (orig_return_values, tuple(adjoints))
+                return ((lv1, lv2), (x_adjoint, y_adjoint))
+    """
+    if require_grads is not None and not isinstance(require_grads, list):
+        require_grads = [require_grads]
+
+    return _ffi_api.Gradient(func_name, require_grads, target_index)  # type: 
ignore
+
+
 def ToNonDataflow() -> tvm.ir.transform.Pass:
     """Transform all dataflow structure to non-dataflow version.
 
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
new file mode 100644
index 0000000000..e7bdea6036
--- /dev/null
+++ b/src/relax/transform/gradient.cc
@@ -0,0 +1,469 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient.cc
+ * \brief Reverse-mode automatic differentiation.
+ *
+ * Currently only supports differentiating one function in the IRModule with one 
dataflow block
+ * with respect to a single scalar target: the return value of the function, or one 
field of a returned tuple.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include <unordered_set>
+
+#include "../op/tensor/binary.h"
+#include "../op/tensor/create.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using AdjointMsg = NestedMsg<Expr>;
+
+// A tool class for GradientMutator
+// Visit the forward bindings and generate the backward bindings
+class BackwardBindingGenerator : private ExprVisitor {
+ public:
+  /*!
+   * \brief Generate the backward bindings for the corresponding 
GradientMutator
+   *
+   * \param builder The BlockBuilder of GradientMutator, used to generate 
bindings
+   * \param forward_block The forward DataflowBlock
+   * \param require_grads The Var list to differentiate w.r.t.
+   * \param target_var The target Var to differentiate
+   * \param orig_return_value The original return value of the function. The 
new return value is a
+   * 2-tuple, containing the original return value, and a tuple of the 
adjoints of parameters.
+   * \return The return expr of new adjoint function.
+   */
+  static Expr Generate(const BlockBuilder& builder, const DataflowBlock& 
forward_block,
+                       const Array<Var>& require_grads, const Var& target_var,
+                       const Expr& orig_return_value) {
+    BackwardBindingGenerator generator(builder);
+
+    // Initialize the adjoint of target_var as ones op. We have already checked 
the target.
+    auto* target_sinfo = GetStructInfoAs<TensorStructInfoNode>(target_var);
+    const Expr& target_adjoint = ones(target_sinfo->shape.value(), 
target_sinfo->dtype);
+    UpdateStructInfo(target_adjoint, GetRef<StructInfo>(target_sinfo));
+    generator.adjoint_msg_map_.Set(target_var, AdjointMsg(target_adjoint));
+
+    // We do reverse-mode ad, so visit bindings backwards
+    for (auto it = forward_block->bindings.rbegin(); it != 
forward_block->bindings.rend(); ++it) {
+      generator.VisitBinding(*it);
+    }
+
+    return generator.Epilogue(require_grads, orig_return_value);
+  }
+
+ private:
+  explicit BackwardBindingGenerator(const BlockBuilder& builder) : 
builder_(builder) {}
+
+  void VisitBinding(const Binding& binding) final {
+    // TODO(chaofan, yixin): support other types of bindings
+    CHECK(binding->IsInstance<VarBindingNode>()) << "now only support 
VarBindingNode";
+    auto* var_binding = binding.as<VarBindingNode>();
+
+    auto it = adjoint_msg_map_.find(var_binding->var);
+    if (it == adjoint_msg_map_.end()) {
+      // This var is not used in the following bindings
+      return;
+    }
+
+    // Meet the definition of binding->var
+    // Create the adjoint var and bind the adjoint value to it
+    EmitAdjoint(var_binding->var, (*it).second, true);
+
+    Expr value = var_binding->value;
+    // TODO(chaofan, yixin): support other types of binding values
+    CHECK(value->IsInstance<CallNode>() || value->IsInstance<TupleNode>() ||
+          value->IsInstance<TupleGetItemNode>() || 
value->IsInstance<VarNode>() ||
+          value->IsInstance<ConstantNode>())
+        << "now does not support the type of binding value: " << value;
+
+    ExprVisitor::VisitBinding_(var_binding);
+  }
+
+  // Handle the adjoint expr of the inputs of binding
+  // For call node, we would call the registered gradient functions
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) 
final {
+    static const OpAttrMap<FPrimalGradient>& gradient_op_map =
+        Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
+
+    Var adjoint_var = adjoint_var_map_[binding->var];
+    const Op& call_op = Downcast<Op>(call->op);
+    const Array<Expr>& partials =
+        gradient_op_map[call_op](binding->var, GetRef<Call>(call), 
adjoint_var, builder_);
+    ICHECK(partials.size() == call->args.size()) << "partials number != inputs 
number";
+
+    for (size_t i = 0; i < partials.size(); ++i) {
+      Expr partial = partials[i];
+      if (IsCallNoGrad(partial)) {  // no grad: don't update
+        continue;
+      }
+      if (!partial->struct_info_.defined()) {
+        UpdateStructInfo(partial, GetStructInfo(call->args[i]));
+      }
+      UpdateAdjoint(call->args[i], partial);
+    }
+  }
+
+  // For Tuple nodes, we would iterate over the input tuple and update adjoint 
exprs for each input
+  // e.g.
+  // a = (b, c)
+  // b_adjoint += a_adjoint_var[0], c_adjoint += a_adjoint_var[1]
+  // a = ((b, c), d)
+  // b_adjoint += a_adjoint_var[0][0], c_adjoint += a_adjoint_var[0][1],
+  // d_adjoint += a_adjoint_var[1]
+  //
+  // Here we use adjoint_var to simplify calculation
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) 
final {
+    UpdateAdjoint(GetRef<Tuple>(tuple), adjoint_var_map_[binding->var]);
+  }
+
+  // For TupleGetItem nodes, we do a partial update
+  // e.g.
+  // b = a[0]
+  // a_adjoint[0] += b_adjoint_var
+  // If a_adjoint does not exist, we would create a zeros tuple as a_adjoint 
first, and then add
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* 
tuple_get_item) final {
+    ICHECK(tuple_get_item->tuple->IsInstance<VarNode>())
+        << "The tuple field of a TupleGetItem is not bound to a Var";
+    auto* tuple_sinfo = 
GetStructInfoAs<TupleStructInfoNode>(tuple_get_item->tuple);
+    ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must has a 
TupleStructInfo";
+
+    const Var& tuple_var = Downcast<Var>(tuple_get_item->tuple);
+    if (adjoint_msg_map_.count(tuple_var) == 0) {
+      const AdjointMsg& init = 
InitZerosAdjointNested(GetRef<StructInfo>(tuple_sinfo));
+      adjoint_msg_map_.Set(tuple_var, init);
+    }
+
+    adjoint_msg_map_.Set(tuple_var,
+                         AddInAdjointMsg(adjoint_msg_map_[tuple_var], 
tuple_get_item->index,
+                                         
ExprToAdjointMsg(adjoint_var_map_[binding->var])));
+  }
+
+  // For assign nodes, we add the adjoint of output to the adjoint of input
+  void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* 
var) final {
+    UpdateAdjoint(GetRef<Var>(var), adjoint_var_map_[binding->var]);
+  }
+
+  void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+    UpdateAdjoint(GetRef<Var>(var), adjoint_var_map_[binding->var]);
+  }
+
+  // For constant nodes, we do not have to handle it because it does not 
contribute to the adjoint
+  void VisitBinding_(const VarBindingNode* binding, const ConstantNode* var) 
final { return; }
+
+  // Add partial (Expr type) to the adjoint of expr
+  void UpdateAdjoint(const Expr& expr, const Expr& partial) {
+    DecomposeNestedMsg(expr, ExprToAdjointMsg(partial), [&](Expr leaf, 
AdjointMsg msg) {
+      if (leaf->IsInstance<VarNode>()) {
+        const Var& v = Downcast<Var>(leaf);
+        if (adjoint_msg_map_.count(v) == 0) {
+          adjoint_msg_map_.Set(v, msg);
+        } else {
+          adjoint_msg_map_.Set(v, TupleAwareAdd(adjoint_msg_map_[v], msg));
+        }
+      } else if (leaf->IsInstance<ConstantNode>()) {
+        // nothing to do
+      } else if (leaf->IsInstance<ShapeExprNode>()) {
+        // must be no grad
+        ICHECK(IsCallNoGrad(partial));
+      } else {
+        LOG(FATAL) << "UpdateAdjoint: leaf type not supported. Currently Var 
and Constant leaves "
+                      "are supported.";
+      }
+    });
+  }
+
+  // Transform the adjoint expressed as NestedMsg<Expr> into adjoint Expr, and 
then emit it
+  // If the adjoint is assigned to a DataflowVar (the adjoint corresponds to a 
non-output binding),
+  // it would be stored in adjoint_var_map_ for future lookup
+  Var EmitAdjoint(const Var& source_var, const AdjointMsg& adjoint, bool 
is_dataflow_var) {
+    Var adjoint_var;
+    if (is_dataflow_var) {
+      adjoint_var = builder_->Emit(AdjointMsgToExpr(adjoint), 
source_var->name_hint() + "_adjoint");
+      adjoint_var_map_.Set(source_var, adjoint_var);
+    } else {
+      adjoint_var =
+          builder_->EmitOutput(AdjointMsgToExpr(adjoint), 
source_var->name_hint() + "_adjoint");
+    }
+    return adjoint_var;
+  }
+
+  // Handle the return value of the AD function.
+  // Returns the new return value, which would be like:
+  // Tuple(original_return_value,
+  //       Tuple(adjoint_of_require_grads_1, adjoint_of_require_grads_2, ...))
+  Expr Epilogue(const Array<Var>& require_grads, const Expr& 
orig_return_value) {
+    // create adjoint variables for inputs, and then bind adjoints
+    Array<Expr> out_adjoints;
+
+    for (Var var : require_grads) {
+      // If the var has no adjoint msg, it does not contribute to the target,
+      // so its adjoint is zeros
+      AdjointMsg adjoint =
+          
adjoint_msg_map_.Get(var).value_or(InitZerosAdjointNested(GetStructInfo(var)));
+      Var adjoint_var = EmitAdjoint(var, adjoint, false);
+      out_adjoints.push_back(adjoint_var);
+    }
+
+    return Tuple({orig_return_value, Tuple(out_adjoints)});
+  }
+
+  static bool IsCallZeros(const Expr& expr) {
+    return expr->IsInstance<CallNode>() && Downcast<Call>(expr)->op == 
Op::Get("relax.zeros");
+  }
+
+  static bool IsCallNoGrad(const Expr& expr) {
+    return expr->IsInstance<CallNode>() &&
+           Downcast<Call>(expr)->op == Op::Get("relax.grad.no_grad");
+  }
+
+  static Expr AdjointMsgToExpr(AdjointMsg msg) {
+    return NestedMsgToExpr<Expr>(msg, [](Optional<Expr> leaf_expr) {
+      if (!leaf_expr.defined()) {
+        LOG(FATAL) << "Null should not exist in AdjointMsg.";
+      }
+      return leaf_expr.value();
+    });
+  }
+
+  static AdjointMsg ExprToAdjointMsg(Expr expr) {
+    return MapToNestedMsgBySInfo<Expr>(expr, [](Expr leaf) {
+      ICHECK(GetStructInfoAs<TensorStructInfoNode>(leaf))
+          << "The leaf of adjoint: " << leaf << " should have StructInfo and 
be a Tensor.";
+      return AdjointMsg(leaf);
+    });
+  }
+
+  // Create a zeros AdjointMsg with specified struct info
+  // When sinfo is TupleStructInfo, we would create a nested zeros Tuple
+  static AdjointMsg InitZerosAdjointNested(const StructInfo& sinfo) {
+    return MapToNestedMsg<Expr>(sinfo, [](StructInfo sinfo) {
+      auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor.";
+      ICHECK(tensor_sinfo->shape.defined()) << "Error: missing shape when 
building zeros tuple.";
+      const Expr& init = zeros(tensor_sinfo->shape.value(), 
tensor_sinfo->dtype);
+      UpdateStructInfo(init, sinfo);
+      return init;
+    });
+  }
+
+  // Return base + increment. A tuple-aware addition.
+  static AdjointMsg TupleAwareAdd(const AdjointMsg& base, const AdjointMsg& 
increment) {
+    return CombineNestedMsg(base, increment, [&](Expr lhs, Expr rhs) {
+      // a small optimization: a+0=a, 0+a=a.
+      if (IsCallZeros(lhs)) {
+        return rhs;
+      } else if (IsCallZeros(rhs)) {
+        return lhs;
+      }
+      auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(lhs);
+      ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a 
Tensor.";
+      ICHECK(GetStructInfoAs<TensorStructInfoNode>(rhs))
+          << "The leaf of adjoint should have StructInfo and be a Tensor.";
+      Expr res = add(lhs, rhs);
+      UpdateStructInfo(res, GetRef<StructInfo>(sinfo));
+      return res;
+    });
+  }
+
+  // Perform an addition in a specified position of tuple.
+  // e.g. tuple=(a, b, c), index=1, increment=d, then return (a, b+d, c)
+  static AdjointMsg AddInAdjointMsg(const AdjointMsg& adjoint, int index,
+                                    const AdjointMsg& increment) {
+    ICHECK(adjoint.IsNested()) << "The adjoint should be nested.";
+    Array<AdjointMsg> arr = adjoint.NestedArray();
+    ICHECK(index >= 0 && index < static_cast<int>(arr.size()));
+    arr.Set(index, TupleAwareAdd(arr[index], increment));
+    return AdjointMsg(arr);
+  }
+
+  // The block builder of the corresponding GradientMutator, to emit bindings
+  BlockBuilder builder_;
+  // Forward Var to its adjoint Var
+  Map<Var, Var> adjoint_var_map_;
+  // Forward Var to its adjoint NestedMsg<Expr>
+  // We use NestedMsg<Expr> to save the adjoint information (equivalent to 
adjoint Expr)
+  // When emitting, adjoint information will be transformed into adjoint Expr
+  Map<Var, AdjointMsg> adjoint_msg_map_;
+};
+
+class GradientMutator : private ExprMutator {
+ public:
+  static IRModule Transform(IRModule mod, String func_name, 
Optional<Array<Var>> require_grads,
+                            int target_index) {
+    auto* old_func_ptr = mod->Lookup(func_name).as<FunctionNode>();
+    CHECK(old_func_ptr) << func_name << "is not a Relax Function";
+    auto old_func = GetRef<Function>(old_func_ptr);
+
+    // when require_grads is not specified, it would be set to all params of 
the function
+    auto require_grads_value = require_grads.value_or(old_func->params);
+
+    CheckRequireGrads(require_grads_value, old_func->params, func_name);
+
+    Function new_func = CopyWithNewVars(old_func);
+    // map the parameter list into new params
+    for (size_t i = 0; i < require_grads_value.size(); ++i) {
+      int idx =
+          std::find(old_func->params.begin(), old_func->params.end(), 
require_grads_value[i]) -
+          old_func->params.begin();
+      require_grads_value.Set(i, new_func->params[idx]);
+    }
+
+    GradientMutator mutator(mod, require_grads_value, target_index);
+    Function new_func_transformed = 
Downcast<Function>(mutator.VisitExpr(new_func));
+
+    IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
+    new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
+    return new_module;
+  }
+
+ private:
+  GradientMutator(const IRModule& module, const Array<Var>& require_grads, int 
target_index)
+      : ExprMutator(module), require_grads_(require_grads), 
target_index_(target_index) {}
+
+  Expr VisitExpr_(const FunctionNode* func) final {
+    CHECK(func->body->IsInstance<SeqExprNode>()) << "The body of the function 
must be SeqExpr.";
+
+    Expr new_body = this->VisitExpr(func->body);
+
+    return Function(func->params, new_body, NullOpt, func->attrs);
+  }
+
+  Expr VisitExpr_(const SeqExprNode* seq_expr) final {
+    // TODO(chaofan, yixin): multiple blocks AD
+    CHECK(seq_expr->blocks.size() == 1) << "now only support one dataflow 
block";
+    // TODO(chaofan, yixin): AD in non-dataflow block.
+    CHECK(seq_expr->blocks[0]->IsInstance<DataflowBlockNode>())
+        << "now only support one dataflow block";
+
+    // the return value should be a VarNode, and a scalar
+    orig_return_expr_ = seq_expr->body;
+    CheckAndSetTarget(seq_expr->body, target_index_);
+
+    BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[0]);
+    return SeqExpr({new_block}, this->return_expr_);
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    builder_->BeginDataflowBlock();
+    // accept bindings in the original block
+    for (const auto& binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+
+    // generate backward bindings and the return value
+    return_expr_ = BackwardBindingGenerator::Generate(this->builder_, 
GetRef<DataflowBlock>(block),
+                                                      this->require_grads_, 
this->target_var_,
+                                                      orig_return_expr_);
+
+    return builder_->EndBlock();
+  }
+
+  static bool IsFloatTensorSInfo(const StructInfo& sinfo) {
+    auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+    return tensor_sinfo && tensor_sinfo->dtype.is_float();
+  }
+
+  // When the return value is a Var, it is the target;
+  // when the return value is a Tuple, the target is its target_index-th field.
+  // Checks that the target is a Var whose struct_info is a float scalar (0-dim) tensor.
+  void CheckAndSetTarget(const Expr& e, int target_index) {
+    if (auto* var = e.as<VarNode>()) {
+      CHECK_EQ(target_index, 0) << "When the function has only one return 
value, target_index can "
+                                   "only be 0. But the target_index specified 
is "
+                                << target_index;
+      target_var_ = GetRef<Var>(var);
+    } else if (auto* tuple = e.as<TupleNode>()) {
+      CHECK(target_index >= 0 && target_index < 
static_cast<int>(tuple->fields.size()))
+          << "target_index should be in the range of the number of return 
values of the function. "
+             "But the specified target_index is "
+          << target_index << ", while the number of return values is " << 
tuple->fields.size();
+      auto* var = tuple->fields[target_index].as<VarNode>();
+      CHECK(var) << "Target must be a Var, but the specified target is "
+                 << tuple->fields[target_index];
+      target_var_ = GetRef<Var>(var);
+    } else {
+      LOG(FATAL) << "The return value of the function must be Var or Tuple. 
However, the return "
+                    "value of the given function is "
+                 << e;
+    }
+    auto target_sinfo = GetStructInfo(target_var_);
+    CHECK(IsScalarTensor(target_sinfo) && IsFloatTensorSInfo(target_sinfo))
+        << "The differentiation target must be a float scalar (0-dim Tensor), 
but the StructInfo "
+           "of the given target "
+        << target_var_ << " is " << GetStructInfo(target_var_);
+  }
+
+  // Check every Var in require_grads:
+  // 1. there should be no duplicate var
+  // 2. every var should be a parameter of the function
+  // 3. the var should be a Tensor of floating point dtype, or a Tuple of such Tensors
+  static void CheckRequireGrads(const Array<Var>& require_grads, const 
Array<Var>& func_params,
+                                const String& func_name) {
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set;
+    for (const auto& var : require_grads) {
+      CHECK(std::find(func_params.begin(), func_params.end(), var) != 
func_params.end())
+          << "There is no Var named " << var->name_hint() << " in the 
parameters of the function "
+          << func_name;
+      CHECK_EQ(var_set.count(var), 0) << "Var " << var->name_hint() << " 
appears more than once";
+      var_set.emplace(var);
+
+      CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo))
+          << "Only Tensors of floating point dtype or Tuples of float "
+             "Tensors can require gradients, but the StructInfo of Var "
+          << var->name_hint() << " is " << GetStructInfo(var);
+    }
+  }
+
+  // differentiation sources
+  Array<Var> require_grads_;
+  // the differentiation target
+  int target_index_;
+  Var target_var_;
+  // the return value of the original function and the differentiated function
+  Expr orig_return_expr_;
+  Expr return_expr_;
+};
+
+namespace transform {
+
+Pass Gradient(String func_name, Optional<Array<Var>> require_grads, int 
target_index) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule mod,
+                                                                            
PassContext pc) {
+    return relax::GradientMutator::Transform(mod, func_name, require_grads, 
target_index);
+  };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"Gradient",
+                          /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc
index 9a19115f62..c0fde3bd4c 100644
--- a/src/relax/transform/utils.cc
+++ b/src/relax/transform/utils.cc
@@ -22,6 +22,19 @@
 namespace tvm {
 namespace relax {
 
+bool IsScalarTensor(const StructInfo& sinfo) {
+  if (!sinfo->IsInstance<TensorStructInfoNode>()) {
+    return false;
+  }
+  TensorStructInfo tensor_sinfo = Downcast<TensorStructInfo>(sinfo);
+  if (!tensor_sinfo->shape.defined() || 
!tensor_sinfo->shape->IsInstance<ShapeExprNode>()) {
+    return false;
+  }
+  return tensor_sinfo->shape.as<ShapeExprNode>()->values.size() == 0;
+}
+
+bool IsScalarTensor(const Expr& expr) { return 
IsScalarTensor(GetStructInfo(expr)); }
+
 bool IsNestedTensor(const StructInfo& sinfo) {
   return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& sinfo) { 
return true; });
 }
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index d51fe53101..9334fd8347 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -143,6 +143,22 @@ IRModule MakeGroupedFunctions(
     const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>& 
partition,
     bool lift_constants = true);
 
+/*!
+ * \brief Check if the given StructInfo is a scalar (0-dim) tensor. Returns false unless sinfo
+ * is a TensorStructInfo whose shape is a ShapeExpr.
+ * \param sinfo The StructInfo to be checked.
+ * \return true if the given StructInfo is a scalar tensor.
+ */
+bool IsScalarTensor(const StructInfo& sinfo);
+
+/*!
+ * \brief Check if the given expr is a scalar (0-dim) tensor. Currently the shape in the
+ * struct info of the expr must be a ShapeExpr.
+ * \param expr The expr to be checked.
+ * \return true if the given expr is a scalar tensor.
+ */
+bool IsScalarTensor(const Expr& expr);
+
 /*!
  * \brief Check if the given StructInfo is a nested tensor StructInfo 
satisfying the given
  * condition f_condition.
diff --git a/tests/python/relax/test_transform_gradient.py 
b/tests/python/relax/test_transform_gradient.py
new file mode 100644
index 0000000000..1b3d174c13
--- /dev/null
+++ b/tests/python/relax/test_transform_gradient.py
@@ -0,0 +1,1164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.script.parser import relax as R, tir as T, ir as I
+from tvm._ffi.base import TVMError
+import numpy as np
+
+
+def test_simple():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv = R.sum(x)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32", 
ndim=0):
+            with R.dataflow():
+                gv: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor(None, "float32", ndim=0),R.Tuple(R.Tensor(None, "float32", 
ndim=2)),):
+            with R.dataflow():
+                gv: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                x_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_assign_binding():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                lv1 = x
+                lv2 = lv1
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = x
+                lv2: R.Tensor((3, 3), "float32") = lv1
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = x
+                lv2: R.Tensor((3, 3), "float32") = lv1
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+                x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_multiple_uses():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                lv1 = R.add(x, x)
+                lv2 = R.add(lv1, x)
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.add(lv2_adjoint, 
lv1_adjoint)
+                x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv, lv1_adjoint)
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_unused():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                lv1 = R.add(x, x)
+                lv2 = R.add(lv1, x)
+                gv = R.sum(x)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+                gv: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+                gv: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                x_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_default_require_grads():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((3, 3), "float32"),
+            y: R.Tensor((3, 3), "float32"),
+            z: R.Tensor((3, 3), "float32"),
+        ):
+            with R.dataflow():
+                lv1 = R.add(x, y)
+                lv2 = R.add(lv1, z)
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            x: R.Tensor((3, 3), "float32"),
+            y: R.Tensor((3, 3), "float32"),
+            z: R.Tensor((3, 3), "float32"),
+        ) -> R.Tensor((), "float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), 
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+                x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+                y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+                z_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+                R.output(gv, x_adjoint, y_adjoint, z_adjoint)
+            return (gv, (x_adjoint, y_adjoint, z_adjoint))
+    # fmt: on
+
+    After1 = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After1, Expected1)
+
+    # fmt: off
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), 
R.Tuple(R.Tensor((3, 3), "float32"))):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+                x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After2 = relax.transform.Gradient("main", 
require_grads=Before["main"].params[0])(Before)
+    assert_structural_equal(After2, Expected2)
+
+
+def test_target_index():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv1 = x
+                lv2 = R.sum(x)
+                lv3 = R.sum(y)
+                R.output(lv1, lv2, lv3)
+            return (lv1, lv2, lv3)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"), 
R.Tensor((), "float32")):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = x
+                lv2: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), "float32") = R.sum(y, axis=None, 
keepdims=False)
+                R.output(lv1, lv2, lv3)
+            return (lv1, lv2, lv3)
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), 
"float32"), R.Tensor((), "float32")), R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = x
+                lv2: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                lv3: R.Tensor((), "float32") = R.sum(y, axis=None, 
keepdims=False)
+                lv3_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                x_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3), 
"float32")
+                y_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(lv3_adjoint, (3, 3))
+                R.output(lv1, lv2, lv3, x_adjoint, y_adjoint)
+            return ((lv1, lv2, lv3), (x_adjoint, y_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main", target_index=2)(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")),
+            y: R.Tensor((3, 3), "float32"),
+            z: R.Tensor((3, 3), "float32"),
+        ):
+            with R.dataflow():
+                lv1 = (y, z)
+                lv2 = x[0]
+                lv3 = lv1[0]
+                lv4 = R.add(lv2, lv3)
+                gv = R.sum(lv4)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> 
R.Tensor(None, "float32", ndim=0):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (y, z)
+                lv2: R.Tensor((3, 3), "float32") = x[0]
+                lv3: R.Tensor((3, 3), "float32") = lv1[0]
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tensor((3, 
3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (y, z)
+                lv2: R.Tensor((3, 3), "float32") = x[0]
+                lv3: R.Tensor((3, 3), "float32") = lv1[0]
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+                lv2_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv3_adjoint, lv)
+                lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                x_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv2_adjoint, lv11)
+                y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[0]
+                z_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+                R.output(gv, x_adjoint, y_adjoint, z_adjoint)
+            return (gv, (x_adjoint, y_adjoint, z_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_assignment():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv1 = (x, y)
+                lv4 = lv1[0]
+                lv7 = R.add(lv4, x)
+                lv2 = lv1
+                lv3 = lv2[0]
+                lv5 = R.add(lv3, lv7)
+                gv = R.sum(lv5)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv4: R.Tensor((3, 3), "float32") = lv1[0]
+                lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x)
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv1
+                lv3: R.Tensor((3, 3), "float32") = lv2[0]
+                lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), 
"float32"), R.Tensor((3, 3), "float32"))):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv4: R.Tensor((3, 3), "float32") = lv1[0]
+                lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x)
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv1
+                lv3: R.Tensor((3, 3), "float32") = lv2[0]
+                lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv5_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv3_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv3_adjoint, lv)
+                lv7_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint
+                lv4_adjoint: R.Tensor((3, 3), "float32") = lv7_adjoint
+                lv11: R.Tensor((3, 3), "float32") = lv2_adjoint[0]
+                lv21: R.Tensor((3, 3), "float32") = R.add(lv11, lv4_adjoint)
+                lv31: R.Tensor((3, 3), "float32") = lv2_adjoint[1]
+                lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv21, lv31)
+                lv41: R.Tensor((3, 3), "float32") = lv1_adjoint[0]
+                x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint, 
lv41)
+                y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+                R.output(gv, x_adjoint, y_adjoint)
+            return (gv, (x_adjoint, y_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_nested():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")), R.Tensor((3, 3), "float32")),
+            y: R.Tensor((3, 3), "float32"),
+            z: R.Tensor((3, 3), "float32"),
+            u: R.Tensor((3, 3), "float32"),
+     ):
+            with R.dataflow():
+                lv1 = ((y, z), u)
+                lv2 = x[0]
+                lv3 = lv2[0]
+                lv4 = lv1[0]
+                lv5 = lv4[1]
+                lv6 = R.add(lv3, lv5)
+                lv7 = x[1]
+                lv8 = R.add(lv6, lv7)
+                gv = R.sum(lv8)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), 
z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tensor((), 
"float32"):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u)
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = x[0]
+                lv3: R.Tensor((3, 3), "float32") = lv2[0]
+                lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv1[0]
+                lv5: R.Tensor((3, 3), "float32") = lv4[1]
+                lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5)
+                lv7: R.Tensor((3, 3), "float32") = x[1]
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv8, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3), 
"float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), 
R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32"))):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u)
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = x[0]
+                lv3: R.Tensor((3, 3), "float32") = lv2[0]
+                lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv1[0]
+                lv5: R.Tensor((3, 3), "float32") = lv4[1]
+                lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5)
+                lv7: R.Tensor((3, 3), "float32") = x[1]
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv8, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv8_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv7_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+                lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+                lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv, lv5_adjoint)
+                lv3_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+                lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv3_adjoint, lv11)
+                lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv1_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv4_adjoint, lv21)
+                x_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv2_adjoint, 
lv7_adjoint)
+                lv31: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv1_adjoint[0]
+                y_adjoint: R.Tensor((3, 3), "float32") = lv31[0]
+                z_adjoint: R.Tensor((3, 3), "float32") = lv31[1]
+                u_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+                R.output(gv, x_adjoint, y_adjoint, z_adjoint, u_adjoint)
+            return (gv, (x_adjoint, y_adjoint, z_adjoint, u_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_update():
+    """One tensor `x` is used in and out of tuple many times."""
+
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv0 = (x, y)
+                lv1 = R.add(x, y)
+                lv2 = lv0[0]
+                lv3 = R.add(lv2, y)
+                lv4 = R.add(lv1, lv3)
+                lv5 = (x, y)
+                lv6 = lv5[0]
+                lv7 = lv0[0]
+                lv8 = R.add(lv4, lv6)
+                lv9 = R.add(lv8, lv7)
+                gv = R.sum(lv9)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = lv0[0]
+                lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y)
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3)
+                lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv6: R.Tensor((3, 3), "float32") = lv5[0]
+                lv7: R.Tensor((3, 3), "float32") = lv0[0]
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6)
+                lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv9, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), 
"float32"), R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+                lv2: R.Tensor((3, 3), "float32") = lv0[0]
+                lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y)
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3)
+                lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = (x, y)
+                lv6: R.Tensor((3, 3), "float32") = lv5[0]
+                lv7: R.Tensor((3, 3), "float32") = lv0[0]
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6)
+                lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7)
+                gv: R.Tensor((), "float32") = R.sum(lv9, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv9_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv8_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint
+                lv7_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint
+                lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv5_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv6_adjoint, lv)
+                lv4_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+                lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+                lv2_adjoint: R.Tensor((3, 3), "float32") = lv3_adjoint
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+                lv11: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint, 
lv2_adjoint)
+                lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv0_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv11, lv21)
+                lv31: R.Tensor((3, 3), "float32") = lv5_adjoint[0]
+                lv41: R.Tensor((3, 3), "float32") = R.add(lv31, lv1_adjoint)
+                lv51: R.Tensor((3, 3), "float32") = lv0_adjoint[0]
+                x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv41, lv51)
+                lv61: R.Tensor((3, 3), "float32") = lv5_adjoint[1]
+                lv71: R.Tensor((3, 3), "float32") = R.add(lv61, lv3_adjoint)
+                lv81: R.Tensor((3, 3), "float32") = R.add(lv71, lv1_adjoint)
+                lv91: R.Tensor((3, 3), "float32") = lv0_adjoint[1]
+                y_adjoint: R.Tensor((3, 3), "float32") = R.add(lv81, lv91)
+                R.output(gv, x_adjoint, y_adjoint)
+            return (gv, (x_adjoint, y_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_op_simple():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((6,), "float32")):
+            with R.dataflow():
+                lv1 = R.split(x, 2)
+                lv2 = R.concat(lv1)
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((6,), "float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(x, indices_or_sections=2, axis=0)
+                lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((6,), "float32")) -> R.Tuple(R.Tensor((), 
"float32"), R.Tuple(R.Tensor((6,), "float32"))):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(x, indices_or_sections=2, axis=0)
+                lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+                gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv2_adjoint: R.Tensor((6,), "float32") = 
R.broadcast_to(gv_adjoint, (6,))
+                lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+                x_adjoint: R.Tensor((6,), "float32") = R.concat(lv1_adjoint, 
axis=0)
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_op_construct():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3, ), 
"float32"), R.Tensor((3, ), "float32")),):
+            with R.dataflow():
+                lv1 = (x, x)
+                lv2 = R.concat(lv1)
+                lv3 = R.concat((x, x))
+                lv4 = R.concat(y)
+                lv5 = R.add(lv2, lv3)
+                lv6 = R.add(lv5, lv4)
+                gv = R.sum(lv6)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), 
"float32"), R.Tensor((3,), "float32"))) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = (x, x)
+                lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+                lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+                lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0)
+                lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3)
+                lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4)
+                gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3,), "float32"), y: 
R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"), 
R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")))):
+            with R.dataflow():
+                lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = (x, x)
+                lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+                lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+                lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0)
+                lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3)
+                lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4)
+                gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv6_adjoint: R.Tensor((6,), "float32") = 
R.broadcast_to(gv_adjoint, (6,))
+                lv5_adjoint: R.Tensor((6,), "float32") = lv6_adjoint
+                lv4_adjoint: R.Tensor((6,), "float32") = lv6_adjoint
+                lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+                lv2_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+                lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+                lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0)
+                lv11: R.Tensor((3,), "float32") = lv[0]
+                lv21: R.Tensor((3,), "float32") = lv[1]
+                lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21)
+                lv41: R.Tensor((3,), "float32") = lv1_adjoint[0]
+                lv51: R.Tensor((3,), "float32") = R.add(lv31, lv41)
+                lv61: R.Tensor((3,), "float32") = lv1_adjoint[1]
+                x_adjoint: R.Tensor((3,), "float32") = R.add(lv51, lv61)
+                y_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv4_adjoint, indices_or_sections=[3], axis=0)
+                R.output(gv, x_adjoint, y_adjoint)
+            return (gv, (x_adjoint, y_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_tuple_op_const():
+    c1 = R.const(np.zeros(3).astype(np.float32))
+    c2 = R.const(np.zeros(3).astype(np.float32))
+    c3 = R.const(np.zeros(3).astype(np.float32))
+
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3,), "float32")):
+            with R.dataflow():
+                lv1 = R.concat((c1, c2))
+                lv2 = R.concat((c3, x))
+                lv3 = R.concat((x, x))
+                lv4 = R.add(lv1, lv2)
+                lv5 = R.add(lv4, lv3)
+                gv = R.sum(lv5)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3,), "float32")) -> R.Tensor((), "float32"):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0)
+                lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0)
+                lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+                lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2)
+                lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((), 
"float32"), R.Tuple(R.Tensor((3,), "float32"))):
+            # block 0
+            with R.dataflow():
+                lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0)
+                lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0)
+                lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+                lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2)
+                lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv5_adjoint: R.Tensor((6,), "float32") = 
R.broadcast_to(gv_adjoint, (6,))
+                lv4_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+                lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+                lv2_adjoint: R.Tensor((6,), "float32") = lv4_adjoint
+                lv1_adjoint: R.Tensor((6,), "float32") = lv4_adjoint
+                lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0)
+                lv11: R.Tensor((3,), "float32") = lv[0]
+                lv21: R.Tensor((3,), "float32") = lv[1]
+                lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21)
+                lv41: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), 
"float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+                lv51: R.Tensor((3,), "float32") = lv41[1]
+                x_adjoint: R.Tensor((3,), "float32") = R.add(lv31, lv51)
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After["main_adjoint"], Expected["main_adjoint"])
+
+
+def test_const():
+    """Constants may appear in a variable binding, as a call argument, and as a 
part of a tuple."""
+    cst = relax.const(np.ones((3, 3)), "float32")
+
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                lv1 = R.add(x, cst)
+                lv2 = cst
+                lv3 = (cst, (cst, lv1))
+                lv4 = lv3[1]
+                lv5 = lv4[1]
+                lv6 = R.subtract(lv5, lv2)
+                gv = R.sum(lv6)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, cst)
+                lv2: R.Tensor((3, 3), "float32") = cst
+                lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 
3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1))
+                lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv3[1]
+                lv5: R.Tensor((3, 3), "float32") = lv4[1]
+                lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2)
+                gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), 
"float32"), R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv1: R.Tensor((3, 3), "float32") = R.add(x, cst)
+                lv2: R.Tensor((3, 3), "float32") = cst
+                lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 
3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1))
+                lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv3[1]
+                lv5: R.Tensor((3, 3), "float32") = lv4[1]
+                lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2)
+                gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+                lv6_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, (3, 3))
+                lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+                lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = (lv, lv5_adjoint)
+                lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+                lv3_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (lv11, 
lv4_adjoint)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.negative(lv6_adjoint)
+                lv21: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv3_adjoint[1]
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv21[1]
+                x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+                y_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3), 
"float32")
+                R.output(gv, x_adjoint, y_adjoint)
+            return (gv, (x_adjoint, y_adjoint))
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_params_copy():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x0: R.Tensor((3, 3), "float32"),
+            x1: R.Tensor((3, 3), "float32"),
+            x2: R.Tensor((3, 3), "float32"),
+            x3: R.Tensor((3, 3), "float32"),
+        ):
+            with R.dataflow():
+                lv0 = R.add(x0, x1)
+                lv1 = R.add(x2, x3)
+                lv2 = R.add(lv0, lv1)
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    After = relax.transform.Gradient("main")(Before)
+    assert len(Before["main"].params) == len(After["main"].params)
+    assert len(Before["main"].params) == len(After["main_adjoint"].params)
+    for i in range(len(After["main"].params)):
+        assert Before["main"].params[i] == After["main"].params[i]
+        assert Before["main"].params[i] != After["main_adjoint"].params[i]
+
+
+def test_function_copy():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x0: R.Tensor((3, 3), "float32"),
+            x1: R.Tensor((3, 3), "float32"),
+            x2: R.Tensor((3, 3), "float32"),
+            x3: R.Tensor((3, 3), "float32"),
+        ):
+            with R.dataflow():
+                lv0 = R.add(x0, x1)
+                lv1 = R.add(x2, x3)
+                lv2 = R.add(lv0, lv1)
+                gv = R.sum(lv2)
+                R.output(gv)
+            return gv
+
+    After = relax.transform.Gradient("main")(Before)
+
+    # After should have the same "main" function as Before
+    assert_structural_equal(Before["main"], After["main"])
+
+    # the first old_bindings_len bindings of After["main_adjoint"] should 
match the bindings of Before["main"]
+    old_bindings = Before["main"].body.blocks[0].bindings
+    old_bindings_len = len(old_bindings)
+    new_bindings = 
After["main_adjoint"].body.blocks[0].bindings[:old_bindings_len]
+    assert_structural_equal(old_bindings, new_bindings, True)
+
+
+def test_report_error():
+    @I.ir_module
+    class TargetNotTensor:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                lv1 = R.sum(x)
+                gv = R.tuple(lv1, lv1)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(TargetNotTensor)
+
+    @I.ir_module
+    class TargetNotScalar:
+        @R.function
+        def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv = R.add(x0, x1)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(TargetNotScalar)
+
+    @I.ir_module
+    class TargetNotFloat:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv = R.const(1)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(TargetNotFloat)
+
+    @I.ir_module
+    class ReturnScalarAndWrongTargetIndex:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv = R.sum(x)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main", 
target_index=1)(ReturnScalarAndWrongTargetIndex)
+
+    @I.ir_module
+    class ReturnTupleAndWrongTargetIndex:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv1 = R.sum(x)
+                gv2 = R.sum(y)
+                R.output(gv1, gv2)
+            return gv1, gv2
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main", 
target_index=2)(ReturnTupleAndWrongTargetIndex)
+
+    @I.ir_module
+    class IndexedTargetNotVar:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv = R.sum(x)
+                R.output(gv)
+            return gv, (gv, gv)
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main", target_index=1)(IndexedTargetNotVar)
+
+    @I.ir_module
+    class NoDataflow:
+        @R.function
+        def main(x0: R.Tensor((3, 3), "float32")):
+            gv = R.sum(x0)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(NoDataflow)
+
+    @I.ir_module
+    class MultiBlocks:
+        @R.function
+        def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), 
"float32")):
+            # block 0
+            with R.dataflow():
+                gv = R.add(x0, x1)
+                R.output(gv)
+            # block 1
+            gv1 = R.sum(x0)
+            return gv1
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(MultiBlocks)
+
+    @I.ir_module
+    class NormalModule:
+        @R.function
+        def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv = R.sum(x0)
+                R.output(gv)
+            return gv
+
+        @T.prim_func
+        def sum(
+            rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"),
+            rxplaceholder_red: T.Buffer((), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for k0, k1 in T.grid(T.int64(3), T.int64(3)):
+                with T.block("rxplaceholder_red"):
+                    v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+                    T.reads(rxplaceholder[v_k0, v_k1])
+                    T.writes(rxplaceholder_red[()])
+                    with T.init():
+                        rxplaceholder_red[()] = T.float32(0)
+                    rxplaceholder_red[()] = rxplaceholder_red[()] + 
rxplaceholder[v_k0, v_k1]
+
+    # no such function
+    with pytest.raises(ValueError):
+        relax.transform.Gradient("main1")(NormalModule)
+    # wrong function type
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("sum")(NormalModule)
+    # no such var
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main", 
require_grads=MultiBlocks["main"].params[0])(NormalModule)
+
+    @I.ir_module
+    class IntDtype:
+        @R.function
+        def main(x: R.Tensor((3, 3), "int64")):
+            with R.dataflow():
+                lv1 = R.add(x, x)
+                gv = R.sum(lv1)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(IntDtype)
+
+    @I.ir_module
+    class IntDtypeTuple:
+        @R.function
+        def main(x: R.Tuple(R.Tensor((3, 3), "int64"), R.Tensor((3, 3), 
"int64"))):
+            with R.dataflow():
+                lv1 = x[0]
+                lv2 = x[1]
+                lv3 = R.add(lv1, lv2)
+                gv = R.sum(lv3)
+                R.output(gv)
+            return gv
+
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main")(IntDtypeTuple)
+
+
+def test_shape_expr():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 4), "float32")):
+            with R.dataflow():
+                s = R.shape([3, 2, 2])
+                lv = R.reshape(x, s)
+                gv = R.sum(lv)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), 
dtype="float32"))):
+            with R.dataflow():
+                s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2])
+                lv = R.reshape(x, s)
+                gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
+                lv_adjoint : R.Tensor([3, 2, 2], "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 2, 2]))
+                x_adjoint: R.Tensor((3, 4), dtype="float32") = 
R.reshape(lv_adjoint, R.shape([3, 4]))
+                R.output(gv, x_adjoint)
+            return (gv, (x_adjoint,))
+
+        @R.function
+        def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+            with R.dataflow():
+                s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2])
+                lv = R.reshape(x, s)
+                gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    After = relax.transform.Gradient("main")(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_mlp_script():
+    """
+    An example of a single-layer perceptron (a one-layer MLP). You can add 
extra layers if you want.
+
+    For n-layer perceptron, see test_transform_gradient_numeric.py.
+    """
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((3, 10), "float32"),
+            w0: R.Tensor((10, 5), "float32"),
+            b0: R.Tensor((5,), "float32"),
+            label: R.Tensor((3, 5), "float32"),
+        ):
+            with R.dataflow():
+                lv0 = R.matmul(x, w0)
+                out = R.add(lv0, b0)
+                logits = R.nn.log_softmax(out)
+                loss = R.nn.cross_entropy_with_logits(logits, label)
+                R.output(loss)
+            return loss
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 10), dtype="float32"), w0: 
R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: 
R.Tensor((3, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), 
R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((5,), dtype="float32"))):
+            with R.dataflow():
+                lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, 
out_dtype="void")
+                out: R.Tensor((3, 5), dtype="float32") = R.add(lv0, b0)
+                logits: R.Tensor((3, 5), dtype="float32") = 
R.nn.log_softmax(out, axis=-1)
+                loss: R.Tensor((), dtype="float32") = 
R.nn.cross_entropy_with_logits(logits, label)
+                loss_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
+                lv: R.Tensor((), dtype="float32") = R.divide(loss_adjoint, 
R.const(3, "float32"))
+                lv1: R.Tensor((), dtype="float32") = R.negative(lv)
+                logits_adjoint: R.Tensor((3, 5), dtype="float32") = 
R.multiply(lv1, label)
+                lv2: R.Tensor((3, 1), dtype="float32") = R.sum(logits_adjoint, 
axis=[-1], keepdims=True)
+                lv3: R.Tensor((3, 5), dtype="float32") = R.exp(logits)
+                lv4: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv3)
+                out_adjoint: R.Tensor((3, 5), dtype="float32") = 
R.subtract(logits_adjoint, lv4)
+                lv0_adjoint: R.Tensor((3, 5), dtype="float32") = out_adjoint
+                lv5: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, 
axes=[1, 0])
+                lv6: R.Tensor((10, 5), dtype="float32") = R.matmul(lv5, 
lv0_adjoint, out_dtype="void")
+                w0_adjoint: R.Tensor((10, 5), dtype="float32") = 
R.collapse_sum_to(lv6, R.shape([10, 5]))
+                b0_adjoint: R.Tensor((5,), dtype="float32") = 
R.collapse_sum_to(out_adjoint, R.shape([5]))
+                R.output(loss, w0_adjoint, b0_adjoint)
+            return (loss, (w0_adjoint, b0_adjoint))
+
+        @R.function
+        def main(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), 
dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), 
dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, 
out_dtype="void")
+                out: R.Tensor((3, 5), dtype="float32") = R.add(lv0, b0)
+                logits: R.Tensor((3, 5), dtype="float32") = 
R.nn.log_softmax(out, axis=-1)
+                loss: R.Tensor((), dtype="float32") = 
R.nn.cross_entropy_with_logits(logits, label)
+                R.output(loss)
+            return loss
+    # fmt: on
+
+    After = relax.transform.Gradient("main", 
require_grads=Before["main"].params[1:3])(Before)
+    assert_structural_equal(After, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_gradient_numeric.py 
b/tests/python/relax/test_transform_gradient_numeric.py
new file mode 100644
index 0000000000..7585ecf1f6
--- /dev/null
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relay.testing import rand
+from tvm.testing import assert_allclose
+from tvm.testing.utils import check_numerical_grads
+from tvm.script.parser import ir as I, relax as R
+from tvm.relax.transform import LegalizeOps
+
+
+def _legalize_and_build(mod, target, dev):
+    lowered_mod = LegalizeOps()(mod)
+    ex = relax.build(lowered_mod, target)
+    vm = relax.VirtualMachine(ex, dev)
+    return vm
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_manual_gradient(target, dev):
+    # The expression computed is sum((2x - 2y) * (y + z))
+    # the gradient of x is broadcast_to(2y + 2z, x.shape)
+    # the gradient of y is collapse_sum_to((2x - 4y - 2z), y.shape)
+    # the gradient of z is collapse_sum_to((2x - 2y), z.shape)
+    # the gradient of u is 0
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((3, 5), "float32"),
+            y: R.Tensor((5,), "float32"),
+            z: R.Tensor((5,), "float32"),
+            u: R.Tensor((5,), "float32"),
+        ):
+            with R.dataflow():
+                lv1 = R.add(x, x)
+                lv2 = R.subtract(lv1, y)
+                lv3 = R.subtract(lv2, y)
+                lv4 = R.add(y, z)
+                lv5 = R.multiply(lv3, lv4)
+                lv6 = R.sum(lv5)
+                R.output(lv6)
+            return lv6
+
+    After = relax.transform.Gradient("main")(Before)
+
+    args = [rand("float32", 3, 5), rand("float32", 5), rand("float32", 5), 
rand("float32", 5)]
+    args_np = [x.numpy() for x in args]
+
+    vm = _legalize_and_build(After, target, dev)
+    output, grads = vm["main_adjoint"](*args)
+    output_np = np.sum((2 * args_np[0] - 2 * args_np[1]) * (args_np[1] + 
args_np[2]))
+    assert_allclose(output.numpy(), output_np, atol=1e-4)
+
+    expected_grads_nd = [
+        (2 * args_np[1] + 2 * args_np[2]) * np.ones_like(args_np[0]),
+        np.sum((2 * args_np[0] - 4 * args_np[1] - 2 * args_np[2]), axis=0),
+        np.sum((2 * args_np[0] - 2 * args_np[1]), axis=0),
+        np.zeros_like(args_np[3]),
+    ]
+    for i, j in zip(grads, expected_grads_nd):
+        assert_allclose(i.numpy(), j, atol=1e-4)
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_mlp_blockbuilder(target, dev):
+    layers, in_size, out_size, hidden_size, batch_size = 3, 5, 5, 5, 4
+
+    input_list = [relax.Var("x", R.Tensor((batch_size, in_size), "float32"))]
+    w_list = (
+        [relax.Var("w_0", R.Tensor((in_size, hidden_size), "float32"))]
+        + [
+            relax.Var("w_" + str(i + 1), R.Tensor((hidden_size, hidden_size), 
"float32"))
+            for i in range(layers - 2)
+        ]
+        + [relax.Var("w_" + str(layers - 1), R.Tensor((hidden_size, out_size), 
"float32"))]
+    )
+    b_list = [
+        relax.Var("b_" + str(i), R.Tensor((hidden_size,), "float32")) for i in 
range(layers - 1)
+    ] + [relax.Var("b_" + str(layers - 1), R.Tensor((out_size,), "float32"))]
+    label_list = [relax.Var("y", R.Tensor((batch_size,), "int64"))]
+    args_list = input_list + w_list + b_list + label_list
+
+    bb = relax.BlockBuilder()
+    with bb.function("MLP", args_list):
+        with bb.dataflow():
+            current = input_list[0]
+            for i in range(layers):
+                lv0 = bb.emit(R.matmul(current, w_list[i]))
+                lv1 = bb.emit(R.add(lv0, b_list[i]))
+                current = bb.emit(R.nn.relu(lv1) if i < layers - 1 else lv1)
+            logits = R.nn.log_softmax(current)
+            loss = bb.emit(R.nn.nll_loss(logits, label_list[0]))
+            gv0 = bb.emit_output(loss)
+        bb.emit_func_output(gv0)
+
+    Before = bb.get()
+    After = relax.transform.Gradient("MLP", w_list + b_list)(Before)
+    # Check numerical gradients equal
+    args = []
+    for arg in After["MLP_adjoint"].params:
+        shape = [int(l) for l in arg.struct_info.shape]
+        if arg.struct_info.dtype == "int64":
+            args.append(tvm.nd.array(np.random.randint(0, out_size, 
size=shape).astype(np.int64)))
+        else:  # float32
+            args.append(rand("float32", *shape))
+
+    vm_before = _legalize_and_build(Before, target, dev)
+    vm_after = _legalize_and_build(After, target, dev)
+    _, grad = vm_after["MLP_adjoint"](*args)
+
+    def func(*inputs):
+        loss = vm_before["MLP"](args[0], *[tvm.nd.array(i) for i in inputs], 
args[-1])
+        return loss.numpy()
+
+    check_numerical_grads(func, [i.numpy() for i in args[1:-1]], [i.numpy() 
for i in grad])
+
+
+@tvm.testing.parametrize_targets("llvm")
+def test_complex(target, dev):
+    cst = relax.const(np.ones((6,)), dtype="float32")
+    cst1 = relax.const(np.array(3), dtype="int64")
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((6,), "float32"), y: R.Tensor((6, 3, 4), 
"float32")):
+            with R.dataflow():
+                lv1 = R.split(x, 2)
+                lv2 = lv1[0]
+                lv3 = lv1[1]
+                lv4 = lv2 + lv3
+                lv5 = (lv4, lv3)
+                lv6 = R.concat(lv5)
+                lv7 = (x, x)
+                lv8 = R.concat(lv7)
+                lv9 = R.concat(lv7)
+                lv10 = R.add(lv8, lv9)
+                lv11 = R.split(lv10, 2)
+                lv12 = R.add(lv6, lv11[0])
+                lv13 = cst
+                lv14 = R.add(lv12, lv13)
+                lv15 = R.subtract(lv13, lv14)
+                lv16 = R.multiply(lv14, lv15)
+                lv17 = R.multiply(lv15, lv16)
+                lv18 = R.tanh(lv17)
+                lv19 = R.sigmoid(lv18)
+                lv20 = R.permute_dims(y, axes=[0, 2, 1])
+                lv21 = R.sigmoid(lv20)
+                lv22 = R.matmul(y, lv21)
+                lv23 = R.sum(lv22, axis=[1, 2])
+                lv24 = R.add(lv19, lv23)
+                lv25 = R.nn.log_softmax(lv24)
+                gv = R.nn.nll_loss(lv25, cst1)
+                R.output(gv)
+            return gv
+
+    After = relax.transform.Gradient("main")(Before)
+    args = []
+    for arg in After["main_adjoint"].params:
+        shape = [int(l) for l in arg.struct_info.shape]
+        args.append(rand("float32", *shape))
+
+    vm_before = _legalize_and_build(Before, target, dev)
+    vm_after = _legalize_and_build(After, target, dev)
+    _, grad = vm_after["main_adjoint"](*args)
+
+    def func(*inputs):
+        loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs])
+        return loss.numpy()
+
+    check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in 
grad])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to