This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e7d12be07a [Unity][Training] Support intermediate vars as require_grads for Gradient pass (#16011)
e7d12be07a is described below

commit e7d12be07aba6e6dedade041d47146df33531fa4
Author: Yixin Dong <[email protected]>
AuthorDate: Mon Nov 6 08:37:29 2023 -0800

    [Unity][Training] Support intermediate vars as require_grads for Gradient pass (#16011)
---
 python/tvm/relax/op/_op_gradient.py                |   8 +-
 python/tvm/relax/op/manipulate.py                  |   2 +-
 src/relax/transform/gradient.cc                    | 111 +++--
 tests/python/relax/test_transform_gradient.py      |  58 +++
 .../relax/test_transform_gradient_checkpoint.py    | 548 ++++++++++++---------
 5 files changed, 434 insertions(+), 293 deletions(-)
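
For readers skimming the change: a minimal, unofficial usage sketch (it closely mirrors the new test_intermediate_var_require_grads added below; the module and variable names are illustrative). With this commit, intermediate dataflow vars, not only function parameters, can be passed to relax.transform.Gradient via require_grads:

    # Hedged sketch based on the new test in this commit; builds a small module and
    # requests gradients w.r.t. a parameter (x), an intermediate var (lv1), and the output (gv0).
    from tvm import relax
    from tvm.script.parser import relax as R

    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3, 3), "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, y]):
        with bb.dataflow():
            lv0 = bb.emit(x * x)
            lv1 = bb.emit(lv0 * y)  # intermediate var inside the dataflow block
            lv2 = bb.emit(lv1 * y)
            gv0 = bb.emit_output(relax.op.sum(lv2))
        bb.emit_func_output(gv0)
    Before = bb.get()

    # Previously only function parameters were accepted in require_grads;
    # intermediate vars such as lv1 (and the output var gv0) are now allowed.
    After = relax.transform.Gradient("main", [x, lv1, gv0])(Before)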

diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index 2873c70ba7..1b0ebfd5e4 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -256,8 +256,8 @@ def maximum_grad(
     y = orig_call.args[1]
     zero = relax.const(0, _get_dtype(x))
     return [
-        where(less(x, y), zero, output_grad),
-        where(greater_equal(x, y), zero, output_grad),
+        _fit_shape(ctx, where(less(x, y), zero, output_grad), x),
+        _fit_shape(ctx, where(greater_equal(x, y), zero, output_grad), y),
     ]
 
 
@@ -280,8 +280,8 @@ def minimum_grad(
     y = orig_call.args[1]
     zero = relax.const(0, _get_dtype(x))
     return [
-        where(greater_equal(x, y), zero, output_grad),
-        where(less(x, y), zero, output_grad),
+        _fit_shape(ctx, where(greater_equal(x, y), zero, output_grad), x),
+        _fit_shape(ctx, where(less(x, y), zero, output_grad), y),
     ]
 
 
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 3dfa371e42..9bd99020e9 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -170,7 +170,7 @@ def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
         The input data to the operator.
 
     axes : Optional[List[int]]
-        The target axes order, reverse order if not specified.
+        The target axes order. If not specified, permute_dims will reverse the order of all axes.
 
     Returns
     -------
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 8560db6a04..2bf858c98d 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -98,18 +98,22 @@ class CheckpointCollector : private ExprMutator {
    * and the end_checkpoint bindings.
    *
    * \param func The original function
-   * \return The function with all start_checkpoint and end_checkpoint bindings removed, and a
-   * VarIdSet containing all checkpointed vars.
+   * \return The function with all start_checkpoint and end_checkpoint bindings removed.
    */
-  static std::pair<Function, VarIdSet> Collect(const Function& func) {
+  Function Transform(const Function& func) {
     auto collector = CheckpointCollector();
-    return std::make_pair(Downcast<Function>(collector.VisitExpr(func)), collector.checkpoints_);
+    return Downcast<Function>(this->VisitExpr(func));
   }
 
+  // checkpointed vars
+  VarIdSet checkpoints;
+  // mapping from vars that are wrapped in start_checkpoint or end_checkpoint to the original vars
+  std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_mapping;
+
  private:
   Expr VisitExpr_(const FunctionNode* func) final {
     for (auto var : func->params) {
-      checkpoints_.insert(var->vid);
+      checkpoints.insert(var->vid);
     }
 
     return ExprMutator::VisitExpr_(func);
@@ -119,7 +123,7 @@ class CheckpointCollector : private ExprMutator {
     static const auto s_cp = Op::Get("relax.grad.start_checkpoint");
     static const auto e_cp = Op::Get("relax.grad.end_checkpoint");
 
-    // If every variable that the variable of binding relys on is either
+    // If every variable that the variable of binding relies on is either
     // 1) the output of end_checkpoint; 2) checkpointed
     // then the variable of binding will be checkpointed
     auto var_binding = binding.as<VarBindingNode>();
@@ -131,12 +135,12 @@ class CheckpointCollector : private ExprMutator {
      PostOrderVisit(var_binding->value, [this, &all_inner_var_checkpointed](const Expr& expr) {
        if (auto var = expr.as<VarNode>()) {
          all_inner_var_checkpointed &=
-              (checkpoints_.count(var->vid) != 0 || e_vars_.count(var->vid) != 0);
+              (checkpoints.count(var->vid) != 0 || e_vars_.count(var->vid) != 0);
         }
       });
 
       if (all_inner_var_checkpointed) {
-        checkpoints_.insert(var_binding->var->vid);
+        checkpoints.insert(var_binding->var->vid);
       }
     }
 
@@ -162,10 +166,11 @@ class CheckpointCollector : private ExprMutator {
       } else {
         this->var_remap_[binding->var->vid] = orig_var;
       }
+      var_mapping[binding->var->vid] = orig_var;
 
       if (value->op == s_cp) {
         // mark the original var to be checkpointed
-        checkpoints_.insert(orig_var->vid);
+        checkpoints.insert(orig_var->vid);
       } else if (value->op == e_cp) {
         e_vars_.insert(binding->var->vid);
       }
@@ -174,7 +179,7 @@ class CheckpointCollector : private ExprMutator {
     }
   }
 
-  VarIdSet checkpoints_;
+  // vars that are the output of end_checkpoint
   VarIdSet e_vars_;
 };
 
@@ -233,6 +238,9 @@ class CheckpointGenerator : private ExprMutator {
   // Visit the use-site of a defined Var
  Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op)); }
 
+  // Visit the use-site of a defined DataflowVar
+  Expr VisitExpr_(const DataflowVarNode* op) final { return VisitVar(GetRef<Var>(op)); }
+
   Expr VisitVar(const Var& var) {
     auto it = checkpoint_map_.find(var);
     if (it != checkpoint_map_.end()) {
@@ -286,9 +294,10 @@ class BackwardBindingGenerator : private ExprVisitor {
  static Expr Generate(const BlockBuilder& builder, const DataflowBlock& forward_block,
                       const Array<Var>& require_grads, const Var& target_var,
                       const Array<Var>& orig_params, const Expr& orig_return_value,
-                       const VarIdSet& checkpoints) {
-    CheckpointGenerator checkpoint_generator(builder, orig_params, forward_block, checkpoints);
-    BackwardBindingGenerator generator(builder, checkpoint_generator);
+                       const CheckpointCollector& cp_collector) {
+    CheckpointGenerator checkpoint_generator(builder, orig_params, forward_block,
+                                             cp_collector.checkpoints);
+    BackwardBindingGenerator generator(builder, cp_collector, checkpoint_generator);

    // Initialize the adjoint of target_var as ones op. We have already checked the target.
     auto* target_sinfo = GetStructInfoAs<TensorStructInfoNode>(target_var);
@@ -304,8 +313,11 @@ class BackwardBindingGenerator : private ExprVisitor {
 
  private:
   explicit BackwardBindingGenerator(const BlockBuilder& builder,
+                                    const CheckpointCollector& cp_collector,
                                    const CheckpointGenerator& checkpoint_generator)
-      : builder_(builder), checkpoint_generator_(checkpoint_generator) {}
+      : builder_(builder),
+        cp_collector_(cp_collector),
+        checkpoint_generator_(checkpoint_generator) {}
 
   void VisitBinding(const Binding& binding) final {
     // TODO(chaofan, yixin): support other types of bindings
@@ -469,6 +481,11 @@ class BackwardBindingGenerator : private ExprVisitor {
     Array<Expr> out_adjoints;
 
     for (Var var : require_grads) {
+      // var might be wrapped in start_checkpoint or end_checkpoint, so we should find the original
+      // var first
+      if (cp_collector_.var_mapping.count(var->vid)) {
+        var = cp_collector_.var_mapping[var->vid];
+      }
      // If the var don't have adjoint var, it do not contribute to the target. So its adjoint is
       // zeros
       auto it = adjoint_var_map_.find(var);
@@ -575,6 +592,8 @@ class BackwardBindingGenerator : private ExprVisitor {
   BlockBuilder builder_;
   // Forward Var to its adjoint Var
   Map<Var, Var> adjoint_var_map_;
+  // information collected by CheckpointCollector
+  CheckpointCollector cp_collector_;
   // The generator for checkpoint bindings
   CheckpointGenerator checkpoint_generator_;
 };
@@ -586,48 +605,42 @@ class GradientMutator : private ExprMutator {
     // Step 1. Copy function
     auto* old_func = mod->Lookup(func_name).as<FunctionNode>();
     CHECK(old_func) << func_name << "is not a Relax Function";
-    auto new_func = CopyWithNewVars(GetRef<Function>(old_func));
+    auto copier = FunctionCopier();
+    auto new_func = copier.Copy(GetRef<Function>(old_func));
 
    // Step 2. Handle the checkpoints and eliminate start_checkpoint and end_checkpoint ops
-    auto checkpoint_collected = CheckpointCollector::Collect(new_func);
-    new_func = checkpoint_collected.first;
-    auto checkpoints = checkpoint_collected.second;
-
-    // Step 3. Collect call_tir_with_grad information
-    auto tir_grad_collected = CallTIRWithGradEliminator::Transform(new_func);
+    auto cp_collector = CheckpointCollector();
+    new_func = cp_collector.Transform(new_func);
 
-    // Step 4. Handle require_grads
+    // Step 3. Handle require_grads
    // When require_grads is not specified, it would be set to all params of the function
-    if (require_grads) {
-      CheckRequireGrads(require_grads.value(), old_func->params, func_name);
+    if (!require_grads) {
+      require_grads = new_func->params;
+    } else {
+      require_grads = CheckAndMapRequireGrads(require_grads.value(), copier.GetVarMap(), func_name);
     }
-    // then map the parameter list into new params
-    auto require_grads_value = require_grads.value_or(old_func->params).Map([&](const Var& v) {
-      return new_func->params[std::find(old_func->params.begin(), old_func->params.end(), v) -
-                              old_func->params.begin()];
-    });
 
-    // Step 5. Generate the adjoint function, use RemoveAllUnused to simplify it, and then return
+    // Step 4. Generate the adjoint function, use RemoveAllUnused to simplify it, and then return
     // the IRModule with the adjoint function
-    return GradientMutator(mod, require_grads_value, target_index, checkpoints)
+    return GradientMutator(mod, require_grads.value(), target_index, cp_collector)
         .AddAdjointFunction(new_func, func_name, true);
   }
 
  private:
  GradientMutator(const IRModule& module, const Array<Var>& require_grads, int target_index,
-                  const VarIdSet& checkpoints)
+                  const CheckpointCollector& cp_collector)
       : ExprMutator(module),
         require_grads_(require_grads),
-        checkpoints_(checkpoints),
+        cp_collector_(cp_collector),
         target_index_(target_index) {}
 
   // Add the adjoint function of func to the IRModule using BlockBuilder
   IRModule AddAdjointFunction(const Function& func, const String& func_name,
                               bool remove_all_unused = true) {
-    // Step 5.1 forward -> forward + backward
+    // Step 4.1 forward -> forward + backward
     auto new_func = Downcast<Function>(VisitExpr(func));
 
-    // Step 5.2 Convert call_tir_with_grad nodes into call_tir nodes
+    // Step 4.2 Convert call_tir_with_grad nodes into call_tir nodes
     // because call_tir_with_grad nodes is not actually implemented
     new_func = CallTIRWithGradEliminator::Transform(new_func);
 
@@ -635,12 +648,12 @@ class GradientMutator : private ExprMutator {
       new_func = Downcast<Function>(RemoveAllUnused(new_func));
     }
 
-    // Step 5.3 mark the transformed function as public
+    // Step 4.3 mark the transformed function as public
    // because the original function may be public, and have gsymbol attribute as func_name
     auto new_func_name = func_name + "_adjoint";
    auto new_func_with_gsymbol = WithAttr(new_func, tvm::attr::kGlobalSymbol, new_func_name);
 
-    // Step 5.4 Add the transformed function to IRModule
+    // Step 4.4 Add the transformed function to IRModule
     builder_->AddFunction(new_func_with_gsymbol, new_func_name);
     return builder_->GetContextIRModule();
   }
@@ -679,7 +692,7 @@ class GradientMutator : private ExprMutator {
     // generate backward bindings and the return value
    return_expr_ = BackwardBindingGenerator::Generate(builder_, GetRef<DataflowBlock>(block),
                                                      require_grads_, target_var_, orig_params_,
-                                                      orig_return_expr_, checkpoints_);
+                                                      orig_return_expr_, cp_collector_);
 
     return builder_->EndBlock();
   }
@@ -700,7 +713,8 @@ class GradientMutator : private ExprMutator {
       target_var_ = GetRef<Var>(var);
     } else if (auto* tuple = e.as<TupleNode>()) {
      CHECK(target_index >= 0 && target_index < static_cast<int>(tuple->fields.size()))
-          << "target_index should be in the range of the number of return values of the function. "
+          << "target_index should be in the range of the number of return values of the "
+             "function. "
             "But the specified target_index is "
          << target_index << ", while the number of return values is " << tuple->fields.size();
       auto* var = tuple->fields[target_index].as<VarNode>();
@@ -721,30 +735,33 @@ class GradientMutator : private ExprMutator {
 
   // Check every Var in require_grads:
   // 1. there should be no duplicate var
-  // 2. every var should be a parameter of the function
+  // 2. every var should be a parameter or a intermediate var in the function
  // 3. the type of the input var should be Tensor of floating point dtype, or Tuple of that
-  static void CheckRequireGrads(const Array<Var>& require_grads, const Array<Var>& func_params,
-                                const String& func_name) {
+  static Array<Var> CheckAndMapRequireGrads(const Array<Var>& require_grads,
+                                            const Map<Var, Var>& var_map, const String& func_name) {
     VarIdSet var_set;
+    Array<Var> mapped_vars;
     for (const auto& var : require_grads) {
-      CHECK(std::find(func_params.begin(), func_params.end(), var) != func_params.end())
-          << "There is no Var named " << var->name_hint() << " in the parameters of the function "
-          << func_name;
+      auto it = var_map.find(var);
+      CHECK(it != var_map.end()) << "There is no Var named " << var->name_hint()
+                                 << " in the function " << func_name;
       CHECK_EQ(var_set.count(var->vid), 0)
           << "Var " << var->name_hint() << " appears more than once";
       var_set.emplace(var->vid);
+      mapped_vars.push_back((*it).second);
 
       CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo))
           << "Only Tensors of floating point dtype or Tuples of float "
              "Tensors can require gradients, but the StructInfo of Var "
           << var->name_hint() << " is " << GetStructInfo(var);
     }
+    return mapped_vars;
   }
 
   // differentiation sources
   Array<Var> require_grads_;
-  // checkpoint
-  VarIdSet checkpoints_;
+  // information collected by CheckpointCollector
+  CheckpointCollector cp_collector_;
   // the differentiation target
   int target_index_;
   Var target_var_;
diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py
index b96932f8c5..c4e2f9d526 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -309,6 +309,64 @@ def test_target_index():
     assert_structural_equal(After, Expected)
 
 
+def test_intermediate_var_require_grads():
+    x = relax.Var("x", R.Tensor((3, 3), "float32"))
+    y = relax.Var("y", R.Tensor((3, 3), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("main", [x, y]):
+        with bb.dataflow():
+            lv0 = bb.emit(x * x)
+            lv1 = bb.emit(lv0 * y)
+            lv2 = bb.emit(lv1 * y)
+            gv0 = bb.emit_output(relax.op.sum(lv2))
+        bb.emit_func_output(gv0)
+
+    Before = bb.get()
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((), dtype="float32"))):
+            with R.dataflow():
+                lv: R.Tensor((3, 3), dtype="float32") = R.multiply(x, x)
+                lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, y)
+                lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, y)
+                gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False)
+                gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32")
+                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2_adjoint, y)
+                lv_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_adjoint, y)
+                x_adjoint: R.Tensor((3, 3), dtype="float32") = R.multiply(lv_adjoint, x)
+                lv1_1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv_adjoint, x)
+                x_adjoint1: R.Tensor((3, 3), dtype="float32") = R.add(x_adjoint, lv1_1)
+                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+                lv1_adjoint_out: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
+                gv_adjoint_out: R.Tensor((), dtype="float32") = gv_adjoint
+                R.output(gv, x_adjoint_out, lv1_adjoint_out, gv_adjoint_out)
+            return (gv, (x_adjoint_out, lv1_adjoint_out, gv_adjoint_out))
+
+        @R.function
+        def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((3, 3), dtype="float32") = R.multiply(x, x)
+                lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, y)
+                lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, y)
+                gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None, keepdims=False)
+                R.output(gv)
+            return gv
+    # fmt: on
+
+    After = relax.transform.Gradient("main", [x, lv1, gv0])(Before)
+    assert_structural_equal(After, Expected)
+
+    # z does not occur in function
+    z = relax.Var("z", R.Tensor((3, 3), "float32"))
+    with pytest.raises(TVMError):
+        relax.transform.Gradient("main", [x, lv1, z])(Before)
+
+
 def test_tuple():
     # fmt: off
     @I.ir_module
diff --git a/tests/python/relax/test_transform_gradient_checkpoint.py b/tests/python/relax/test_transform_gradient_checkpoint.py
index 3e94125b77..9de62d341c 100644
--- a/tests/python/relax/test_transform_gradient_checkpoint.py
+++ b/tests/python/relax/test_transform_gradient_checkpoint.py
@@ -16,12 +16,13 @@
 # under the License.
 """Unit tests for gradient with checkpointing."""
 import tvm
-from tvm.relax.block_builder import BlockBuilder
-from tvm.relax.testing.nn import checkpoint, emit_checkpoint, emit_checkpoint_sequential
 import tvm.testing
+
 from tvm import relax
 from tvm.ir.base import assert_structural_equal
-from tvm.script.parser import relax as R, ir as I
+from tvm.relax.block_builder import BlockBuilder
+from tvm.relax.testing import nn
+from tvm.script.parser import ir as I, relax as R
 
 
 def test_sequential():
@@ -47,51 +48,51 @@ def test_sequential():
     @I.ir_module
     class Expected:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), 
dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, 
"float32"))
-                lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, 
R.const(3, "float32"))
-                lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2, 
R.const(3, "float32"))
-                lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, 
R.const(3, "float32"))
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, 
keepdims=False)
-                gv_1: R.Tensor((), dtype="float32") = gv
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                gv_adjoint1: R.Tensor((), dtype="float32") = gv_adjoint
-                lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2, 
R.const(3, "float32"))
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint1, R.shape([3, 3]))
-                lv: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, R.const(3, "float32"))
-                lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp, 
lv1_1)
-                lv3_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv, lv2_1)
-                lv6: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint, R.const(3, "float32"))
-                lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7)
-                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv6, lv8)
-                lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, 
R.const(3, "float32"))
-                lv12: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, R.const(3, "float32"))
-                lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13)
-                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv12, lv14)
-                lv18: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, R.const(3, "float32"))
-                lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19)
-                x_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv18, lv20)
-                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+                lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3, 
"float32"))
+                lv3: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3, 
"float32"))
+                lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3, 
"float32"))
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_1: R.Tensor((), "float32") = gv
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                gv_adjoint1: R.Tensor((), "float32") = gv_adjoint
+                lv3_cp: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3, 
"float32"))
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint1, R.shape([3, 3]))
+                lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint, 
R.const(3, "float32"))
+                lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv2_1: R.Tensor((3, 3), "float32") = R.power(lv3_cp, lv1_1)
+                lv3_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv, 
lv2_1)
+                lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3_adjoint, 
R.const(3, "float32"))
+                lv7: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv8: R.Tensor((3, 3), "float32") = R.power(lv2, lv7)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+                lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv12: R.Tensor((3, 3), "float32") = R.multiply(lv2_adjoint, 
R.const(3, "float32"))
+                lv13: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv14: R.Tensor((3, 3), "float32") = R.power(lv1_cp, lv13)
+                lv1_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv12, 
lv14)
+                lv18: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint, 
R.const(3, "float32"))
+                lv19: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv20: R.Tensor((3, 3), "float32") = R.power(x, lv19)
+                x_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv18, lv20)
+                x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
                 R.output(gv_1, x_adjoint_out)
             return (gv_1, (x_adjoint_out,))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
-                x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x)
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, 
R.const(3, "float32"))
-                lv1_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv1)
-                lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1_ecp, 
R.const(3, "float32"))
-                lv2_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv2)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2_scp, 
R.const(3, "float32"))
-                lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, 
R.const(3, "float32"))
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, 
keepdims=False)
-                gv_ecp: R.Tensor((), dtype="float32") = 
R.grad.end_checkpoint(gv)
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3, 
"float32"))
+                lv1_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv1)
+                lv2: R.Tensor((3, 3), "float32") = R.power(lv1_ecp, R.const(3, 
"float32"))
+                lv2_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv2)
+                lv3: R.Tensor((3, 3), "float32") = R.power(lv2_scp, R.const(3, 
"float32"))
+                lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3, 
"float32"))
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_ecp: R.Tensor((), "float32") = R.grad.end_checkpoint(gv)
                 R.output(gv_ecp)
             return gv_ecp
     # fmt: on
@@ -123,49 +124,49 @@ def test_sequential_consecutive():
     @I.ir_module
     class Expected:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), 
dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, 
"float32"))
-                lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, 
R.const(3, "float32"))
-                lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2, 
R.const(3, "float32"))
-                lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, 
R.const(3, "float32"))
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, 
keepdims=False)
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2, 
R.const(3, "float32"))
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
-                lv: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, R.const(3, "float32"))
-                lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp, 
lv1_1)
-                lv3_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv, lv2_1)
-                lv6: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint, R.const(3, "float32"))
-                lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7)
-                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv6, lv8)
-                lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, 
R.const(3, "float32"))
-                lv12: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, R.const(3, "float32"))
-                lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13)
-                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv12, lv14)
-                lv18: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, R.const(3, "float32"))
-                lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19)
-                x_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv18, lv20)
-                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+                lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3, 
"float32"))
+                lv3: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3, 
"float32"))
+                lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3, 
"float32"))
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv3_cp: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3, 
"float32"))
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint, 
R.const(3, "float32"))
+                lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv2_1: R.Tensor((3, 3), "float32") = R.power(lv3_cp, lv1_1)
+                lv3_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv, 
lv2_1)
+                lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3_adjoint, 
R.const(3, "float32"))
+                lv7: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv8: R.Tensor((3, 3), "float32") = R.power(lv2, lv7)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+                lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv12: R.Tensor((3, 3), "float32") = R.multiply(lv2_adjoint, 
R.const(3, "float32"))
+                lv13: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv14: R.Tensor((3, 3), "float32") = R.power(lv1_cp, lv13)
+                lv1_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv12, 
lv14)
+                lv18: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint, 
R.const(3, "float32"))
+                lv19: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv20: R.Tensor((3, 3), "float32") = R.power(x, lv19)
+                x_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv18, lv20)
+                x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
                 R.output(gv, x_adjoint_out)
             return (gv, (x_adjoint_out,))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
-                x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x)
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, 
R.const(3, "float32"))
-                lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1, 
R.const(3, "float32"))
-                lv2_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv2)
-                lv2_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv2_ecp)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2_scp, 
R.const(3, "float32"))
-                lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3, 
R.const(3, "float32"))
-                lv4_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv4)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None, 
keepdims=False)
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3, 
"float32"))
+                lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3, 
"float32"))
+                lv2_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv2)
+                lv2_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv2_ecp)
+                lv3: R.Tensor((3, 3), "float32") = R.power(lv2_scp, R.const(3, 
"float32"))
+                lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3, 
"float32"))
+                lv4_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv4)
+                gv: R.Tensor((), "float32") = R.sum(lv4_ecp, axis=None, 
keepdims=False)
                 R.output(gv)
             return gv
 
@@ -196,49 +197,49 @@ def test_tuple():
     @I.ir_module
     class Expected:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), 
dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3, 
"float32"))
-                lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 
3), dtype="float32")) = x, lv1
-                lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 
3), dtype="float32")) = lv2
-                lv4: R.Tensor((3, 3), dtype="float32") = lv3[0]
-                lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4, 
R.const(3, "float32"))
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4_1, axis=None, 
keepdims=False)
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x, 
R.const(3, "float32"))
-                lv2_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"), 
R.Tensor((3, 3), dtype="float32")) = x, lv1_cp
-                lv3_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"), 
R.Tensor((3, 3), dtype="float32")) = lv2_cp
-                lv4_cp: R.Tensor((3, 3), dtype="float32") = lv3_cp[0]
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
-                lv: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, R.const(3, "float32"))
-                lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4_cp, 
lv1_1)
-                lv4_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv, lv2_1)
-                lv6: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3, 
3]), dtype="float32")
-                lv3_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), 
R.Tensor((3, 3), dtype="float32")) = lv4_adjoint1, lv6
-                lv2_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"), 
R.Tensor((3, 3), dtype="float32")) = lv3_adjoint
-                x_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[0]
-                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[1]
-                lv7: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, R.const(3, "float32"))
-                lv8: R.Tensor((), dtype="float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
-                lv9: R.Tensor((3, 3), dtype="float32") = R.power(x, lv8)
-                lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7, lv9)
-                x_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.add(x_adjoint, lv12)
-                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+                lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = x, lv1
+                lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv2
+                lv4: R.Tensor((3, 3), "float32") = lv3[0]
+                lv4_1: R.Tensor((3, 3), "float32") = R.power(lv4, R.const(3, 
"float32"))
+                gv: R.Tensor((), "float32") = R.sum(lv4_1, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3, 
"float32"))
+                lv2_cp: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = x, lv1_cp
+                lv3_cp: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv2_cp
+                lv4_cp: R.Tensor((3, 3), "float32") = lv3_cp[0]
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint, 
R.const(3, "float32"))
+                lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv2_1: R.Tensor((3, 3), "float32") = R.power(lv4_cp, lv1_1)
+                lv4_adjoint1: R.Tensor((3, 3), "float32") = R.multiply(lv, 
lv2_1)
+                lv6: R.Tensor((3, 3), "float32") = R.zeros(R.shape([3, 3]), 
"float32")
+                lv3_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = lv4_adjoint1, lv6
+                lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32")) = lv3_adjoint
+                x_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint[0]
+                lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint[1]
+                lv7: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint, 
R.const(3, "float32"))
+                lv8: R.Tensor((), "float32") = R.subtract(R.const(3, 
"float32"), R.const(1, "float32"))
+                lv9: R.Tensor((3, 3), "float32") = R.power(x, lv8)
+                lv12: R.Tensor((3, 3), "float32") = R.multiply(lv7, lv9)
+                x_adjoint1: R.Tensor((3, 3), "float32") = R.add(x_adjoint, 
lv12)
+                x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint1
                 R.output(gv, x_adjoint_out)
             return (gv, (x_adjoint_out,))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
-                x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x)
-                lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp, 
R.const(3, "float32"))
-                lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 
3), dtype="float32")) = x, lv1
-                lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 
3), dtype="float32")) = lv2
-                lv4: R.Tensor((3, 3), dtype="float32") = lv3[0]
-                lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4, 
R.const(3, "float32"))
-                lv4_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv4_1)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None, 
keepdims=False)
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3, 
"float32"))
+                lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = x, lv1
+                lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), 
"float32")) = lv2
+                lv4: R.Tensor((3, 3), "float32") = lv3[0]
+                lv4_1: R.Tensor((3, 3), "float32") = R.power(lv4, R.const(3, 
"float32"))
+                lv4_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv4_1)
+                gv: R.Tensor((), "float32") = R.sum(lv4_ecp, axis=None, 
keepdims=False)
                 R.output(gv)
             return gv
     # fmt: on
@@ -272,35 +273,35 @@ def test_tree():
     @I.ir_module
     class Expected1:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 
3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), 
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), 
dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), 
dtype="float32"), R.Tensor((3, 3), dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: 
R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), 
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 
3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y)
-                lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
-                lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, 
keepdims=False)
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
-                lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
-                lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
-                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, lv3_cp)
-                lv3_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, lv2_cp)
-                u_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint, v)
-                v_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint, u)
-                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, z)
-                z_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, lv1)
-                x_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, y)
-                y_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, x)
-                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
-                y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint
-                z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint
-                u_adjoint_out: R.Tensor((3, 3), dtype="float32") = u_adjoint
-                v_adjoint_out: R.Tensor((3, 3), dtype="float32") = v_adjoint
+                lv1: R.Tensor((3, 3), "float32") = R.multiply(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+                lv3: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+                lv4: R.Tensor((3, 3), "float32") = R.multiply(lv2, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv2_cp: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+                lv3_cp: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv4_adjoint, lv3_cp)
+                lv3_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv4_adjoint, lv2_cp)
+                u_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv3_adjoint, v)
+                v_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv3_adjoint, u)
+                lv1_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv2_adjoint, z)
+                z_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv2_adjoint, lv1)
+                x_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv1_adjoint, y)
+                y_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv1_adjoint, x)
+                x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
+                y_adjoint_out: R.Tensor((3, 3), "float32") = y_adjoint
+                z_adjoint_out: R.Tensor((3, 3), "float32") = z_adjoint
+                u_adjoint_out: R.Tensor((3, 3), "float32") = u_adjoint
+                v_adjoint_out: R.Tensor((3, 3), "float32") = v_adjoint
                 R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out, 
u_adjoint_out, v_adjoint_out)
             return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out, 
u_adjoint_out, v_adjoint_out))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), 
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: 
R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
                 lv1 = x * y
                 lv1_scp = R.grad.start_checkpoint(lv1)
@@ -324,24 +325,24 @@ def test_tree():
     @I.ir_module
     class Expected2:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 
3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), 
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((), 
dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: 
R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), 
R.Tuple(R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y)
-                lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
-                lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None, 
keepdims=False)
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
-                lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
-                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, lv3_cp)
-                z_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, lv1)
-                z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint
+                lv1: R.Tensor((3, 3), "float32") = R.multiply(x, y)
+                lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+                lv3: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+                lv4: R.Tensor((3, 3), "float32") = R.multiply(lv2, lv3)
+                gv: R.Tensor((), "float32") = R.sum(lv4, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv3_cp: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv4_adjoint, lv3_cp)
+                z_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv2_adjoint, lv1)
+                z_adjoint_out: R.Tensor((3, 3), "float32") = z_adjoint
                 R.output(gv, z_adjoint_out)
             return (gv, (z_adjoint_out,))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), 
dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3), 
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v: 
R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
                 lv1 = x * y
                 lv1_scp = R.grad.start_checkpoint(lv1)
@@ -387,54 +388,54 @@ def test_dag():
                 lv12 = R.multiply(lv11, R.const(2, "float32"))
                 lv13 = R.grad.end_checkpoint(lv12)
                 lv14 = R.multiply(lv9, lv13)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, 
keepdims=False)
+                gv: R.Tensor((), "float32") = R.sum(lv14, axis=None, 
keepdims=False)
                 R.output(gv)
             return gv
 
     @I.ir_module
     class Expected:
         @R.function
-        def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) -> 
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3), 
dtype="float32"))):
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
             with R.dataflow():
-                lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, 
R.const(2, "float32"))
-                lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, 
R.const(2, "float32"))
-                lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(x, lv2)
-                lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, 
R.const(2, "float32"))
-                lv5: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4, 
R.const(2, "float32"))
-                lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, lv5)
-                lv7: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, 
R.const(2, "float32"))
-                lv8: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7, 
R.const(2, "float32"))
-                lv9: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, lv8)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv9, axis=None, 
keepdims=False)
-                gv_adjoint: R.Tensor((), dtype="float32") = 
R.ones(R.shape([]), dtype="float32")
-                lv9_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
-                lv7_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, 
R.const(2, "float32"))
-                lv8_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7_cp, 
R.const(2, "float32"))
-                lv6_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv9_adjoint, lv8_cp)
-                lv8_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv9_adjoint, lv6)
-                lv7_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv8_adjoint, R.const(2, "float32"))
-                lv1_1: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv7_adjoint, R.const(2, "float32"))
-                lv6_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.add(lv6_adjoint, lv1_1)
-                lv4_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, 
R.const(2, "float32"))
-                lv5_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_cp, 
R.const(2, "float32"))
-                lv3_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv6_adjoint1, lv5_cp)
-                lv5_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv6_adjoint1, lv3)
-                lv4_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv5_adjoint, R.const(2, "float32"))
-                lv4_1: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv4_adjoint, R.const(2, "float32"))
-                lv3_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.add(lv3_adjoint, lv4_1)
-                lv1_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(x, 
R.const(2, "float32"))
-                lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_cp, 
R.const(2, "float32"))
-                x_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint1, lv2_cp)
-                lv2_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv3_adjoint1, x)
-                lv1_adjoint: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv2_adjoint, R.const(2, "float32"))
-                lv7_1: R.Tensor((3, 3), dtype="float32") = 
R.multiply(lv1_adjoint, R.const(2, "float32"))
-                x_adjoint1: R.Tensor((3, 3), dtype="float32") = 
R.add(x_adjoint, lv7_1)
-                x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+                lv1: R.Tensor((3, 3), "float32") = R.multiply(x, R.const(2, 
"float32"))
+                lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, R.const(2, 
"float32"))
+                lv3: R.Tensor((3, 3), "float32") = R.multiply(x, lv2)
+                lv4: R.Tensor((3, 3), "float32") = R.multiply(lv3, R.const(2, 
"float32"))
+                lv5: R.Tensor((3, 3), "float32") = R.multiply(lv4, R.const(2, 
"float32"))
+                lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3, lv5)
+                lv7: R.Tensor((3, 3), "float32") = R.multiply(lv6, R.const(2, 
"float32"))
+                lv8: R.Tensor((3, 3), "float32") = R.multiply(lv7, R.const(2, 
"float32"))
+                lv9: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+                gv: R.Tensor((), "float32") = R.sum(lv9, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv9_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv7_cp: R.Tensor((3, 3), "float32") = R.multiply(lv6, 
R.const(2, "float32"))
+                lv8_cp: R.Tensor((3, 3), "float32") = R.multiply(lv7_cp, 
R.const(2, "float32"))
+                lv6_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv9_adjoint, lv8_cp)
+                lv8_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv9_adjoint, lv6)
+                lv7_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv8_adjoint, R.const(2, "float32"))
+                lv1_1: R.Tensor((3, 3), "float32") = R.multiply(lv7_adjoint, 
R.const(2, "float32"))
+                lv6_adjoint1: R.Tensor((3, 3), "float32") = R.add(lv6_adjoint, 
lv1_1)
+                lv4_cp: R.Tensor((3, 3), "float32") = R.multiply(lv3, 
R.const(2, "float32"))
+                lv5_cp: R.Tensor((3, 3), "float32") = R.multiply(lv4_cp, 
R.const(2, "float32"))
+                lv3_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv6_adjoint1, lv5_cp)
+                lv5_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv6_adjoint1, lv3)
+                lv4_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv5_adjoint, R.const(2, "float32"))
+                lv4_1: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint, 
R.const(2, "float32"))
+                lv3_adjoint1: R.Tensor((3, 3), "float32") = R.add(lv3_adjoint, 
lv4_1)
+                lv1_cp: R.Tensor((3, 3), "float32") = R.multiply(x, R.const(2, 
"float32"))
+                lv2_cp: R.Tensor((3, 3), "float32") = R.multiply(lv1_cp, 
R.const(2, "float32"))
+                x_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv3_adjoint1, lv2_cp)
+                lv2_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv3_adjoint1, x)
+                lv1_adjoint: R.Tensor((3, 3), "float32") = 
R.multiply(lv2_adjoint, R.const(2, "float32"))
+                lv7_1: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint, 
R.const(2, "float32"))
+                x_adjoint1: R.Tensor((3, 3), "float32") = R.add(x_adjoint, 
lv7_1)
+                x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint1
                 R.output(gv, x_adjoint_out)
             return (gv, (x_adjoint_out,))
 
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
                 lv = R.grad.start_checkpoint(x)
                 lv1 = R.multiply(lv, R.const(2, "float32"))
@@ -451,7 +452,7 @@ def test_dag():
                 lv12 = R.multiply(lv11, R.const(2, "float32"))
                 lv13 = R.grad.end_checkpoint(lv12)
                 lv14 = R.multiply(lv9, lv13)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, 
keepdims=False)
+                gv: R.Tensor((), "float32") = R.sum(lv14, axis=None, 
keepdims=False)
                 R.output(gv)
             return gv
     # fmt: on
@@ -474,9 +475,9 @@ def test_checkpoint_api():
     x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv1 = bb.emit(checkpoint(func1, x))
+            lv1 = bb.emit(nn.checkpoint(func1, x))
             lv2 = bb.emit(relax.op.power(lv1, relax.const(3, "float32")))
-            lv3 = bb.emit_output(checkpoint(func2, lv2))
+            lv3 = bb.emit_output(nn.checkpoint(func2, lv2))
         bb.emit_func_output(lv3)
 
     # fmt: off
@@ -516,7 +517,7 @@ def test_checkpoint_tree():
     with bb.function("main", [x, y, z, u, v]):
         with bb.dataflow():
             lv1 = bb.emit(x * y)
-            cp = checkpoint(func, lv1, z, u, v)
+            cp = nn.checkpoint(func, lv1, z, u, v)
             lv2 = bb.emit(cp[0])
             lv3 = bb.emit(cp[1])
             lv4 = bb.emit(lv2 * lv3)
@@ -559,11 +560,11 @@ def test_checkpoint_dag():
     x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv1 = bb.emit(checkpoint(func, x))
+            lv1 = bb.emit(nn.checkpoint(func, x))
             lv2 = bb.emit(x * lv1)
-            lv3 = bb.emit(checkpoint(func, lv2))
+            lv3 = bb.emit(nn.checkpoint(func, lv2))
             lv4 = bb.emit(lv2 * lv3)
-            lv5 = bb.emit(checkpoint(func, lv4))
+            lv5 = bb.emit(nn.checkpoint(func, lv4))
             lv6 = bb.emit(lv4 * lv5)
             gv = bb.emit_output(relax.op.sum(lv6))
         bb.emit_func_output(gv)
@@ -572,7 +573,7 @@ def test_checkpoint_dag():
     @I.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
             with R.dataflow():
                 lv = R.grad.start_checkpoint(x)
                 lv1 = R.multiply(lv, R.const(2, "float32"))
@@ -589,7 +590,7 @@ def test_checkpoint_dag():
                 lv12 = R.multiply(lv11, R.const(2, "float32"))
                 lv13 = R.grad.end_checkpoint(lv12)
                 lv14 = R.multiply(lv9, lv13)
-                gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None, 
keepdims=False)
+                gv: R.Tensor((), "float32") = R.sum(lv14, axis=None, 
keepdims=False)
                 R.output(gv)
             return gv
     # fmt: on
@@ -605,8 +606,8 @@ def test_checkpoint_sequential():
     x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv1 = emit_checkpoint_sequential([func] * 5, 2, x)
-            lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1)
+            lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x)
+            lv2 = nn.emit_checkpoint_sequential([func] * 4, 2, lv1)
             gv = bb.emit_output(lv2)
         bb.emit_func_output(gv)
 
@@ -614,24 +615,24 @@ def test_checkpoint_sequential():
     @I.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((3, 3), 
"float32"):
             with R.dataflow():
-                x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x)
-                lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp)
-                lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv)
-                lv1_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv1)
-                lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv1_ecp)
-                lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp, 
lv1_ecp_scp)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2)
-                lv3_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv3)
-                lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp, 
lv3_ecp)
-                lv4_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv4)
-                lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_scp, 
lv4_scp)
-                lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5)
-                lv6_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv6)
-                lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp, 
lv6_ecp)
-                lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7)
-                gv: R.Tensor((3, 3), dtype="float32") = lv8
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv: R.Tensor((3, 3), "float32") = R.add(x_scp, x_scp)
+                lv1: R.Tensor((3, 3), "float32") = R.add(lv, lv)
+                lv1_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv1)
+                lv1_ecp_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv1_ecp)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1_ecp_scp, 
lv1_ecp_scp)
+                lv3: R.Tensor((3, 3), "float32") = R.add(lv2, lv2)
+                lv3_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv3)
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv3_ecp, lv3_ecp)
+                lv4_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv4)
+                lv5: R.Tensor((3, 3), "float32") = R.add(lv4_scp, lv4_scp)
+                lv6: R.Tensor((3, 3), "float32") = R.add(lv5, lv5)
+                lv6_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv6)
+                lv7: R.Tensor((3, 3), "float32") = R.add(lv6_ecp, lv6_ecp)
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv7, lv7)
+                gv: R.Tensor((3, 3), "float32") = lv8
                 R.output(gv)
             return gv
     # fmt: on
@@ -647,8 +648,8 @@ def test_checkpoint_sequential_checkpoint_last():
     x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
     with bb.function("main", [x]):
         with bb.dataflow():
-            lv1 = emit_checkpoint_sequential([func] * 5, 2, x, 
checkpoint_last=True)
-            lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1, 
checkpoint_last=True)
+            lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x, 
checkpoint_last=True)
+            lv2 = nn.emit_checkpoint_sequential([func] * 4, 2, lv1, 
checkpoint_last=True)
             gv = bb.emit_output(lv2)
         bb.emit_func_output(gv)
 
@@ -656,28 +657,28 @@ def test_checkpoint_sequential_checkpoint_last():
     @I.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3), 
dtype="float32"):
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((3, 3), 
"float32"):
             with R.dataflow():
-                x_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(x)
-                lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp)
-                lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv)
-                lv1_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv1)
-                lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv1_ecp)
-                lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp, 
lv1_ecp_scp)
-                lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2)
-                lv3_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv3)
-                lv3_ecp_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv3_ecp)
-                lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp_scp, 
lv3_ecp_scp)
-                lv4_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv4)
-                lv4_ecp_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv4_ecp)
-                lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_ecp_scp, 
lv4_ecp_scp)
-                lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5)
-                lv6_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv6)
-                lv6_ecp_scp: R.Tensor((3, 3), dtype="float32") = 
R.grad.start_checkpoint(lv6_ecp)
-                lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp_scp, 
lv6_ecp_scp)
-                lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7)
-                lv8_ecp: R.Tensor((3, 3), dtype="float32") = 
R.grad.end_checkpoint(lv8)
-                gv: R.Tensor((3, 3), dtype="float32") = lv8_ecp
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv: R.Tensor((3, 3), "float32") = R.add(x_scp, x_scp)
+                lv1: R.Tensor((3, 3), "float32") = R.add(lv, lv)
+                lv1_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv1)
+                lv1_ecp_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv1_ecp)
+                lv2: R.Tensor((3, 3), "float32") = R.add(lv1_ecp_scp, 
lv1_ecp_scp)
+                lv3: R.Tensor((3, 3), "float32") = R.add(lv2, lv2)
+                lv3_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv3)
+                lv3_ecp_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv3_ecp)
+                lv4: R.Tensor((3, 3), "float32") = R.add(lv3_ecp_scp, 
lv3_ecp_scp)
+                lv4_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv4)
+                lv4_ecp_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv4_ecp)
+                lv5: R.Tensor((3, 3), "float32") = R.add(lv4_ecp_scp, 
lv4_ecp_scp)
+                lv6: R.Tensor((3, 3), "float32") = R.add(lv5, lv5)
+                lv6_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv6)
+                lv6_ecp_scp: R.Tensor((3, 3), "float32") = 
R.grad.start_checkpoint(lv6_ecp)
+                lv7: R.Tensor((3, 3), "float32") = R.add(lv6_ecp_scp, 
lv6_ecp_scp)
+                lv8: R.Tensor((3, 3), "float32") = R.add(lv7, lv7)
+                lv8_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv8)
+                gv: R.Tensor((3, 3), "float32") = lv8_ecp
                 R.output(gv)
             return gv
     # fmt: on
@@ -685,5 +686,70 @@ def test_checkpoint_sequential_checkpoint_last():
     assert_structural_equal(bb.get(), Expected)
 
 
+def test_checkpoint_with_intermediate_require_grads():
+    def func(x):
+        return x * x * x
+
+    bb = BlockBuilder()
+    x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
+    with bb.function("main", [x]):
+        with bb.dataflow():
+            lv1 = nn.emit_checkpoint(func, x)
+            gv = bb.emit_output(relax.op.sum(lv1))
+        bb.emit_func_output(gv)
+
+    # fmt: off
+    @I.ir_module
+    class Expected:
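+        # Adjoints are returned only for require_grads (here lv1)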
+        @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+            with R.dataflow():
+                lv: R.Tensor((3, 3), "float32") = R.multiply(x, x)
+                lv1: R.Tensor((3, 3), "float32") = R.multiply(lv, x)
+                gv: R.Tensor((), "float32") = R.sum(lv1, axis=None, 
keepdims=False)
+                gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]), 
"float32")
+                lv1_adjoint: R.Tensor((3, 3), "float32") = 
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+                lv1_adjoint_out: R.Tensor((3, 3), "float32") = lv1_adjoint
+                R.output(gv, lv1_adjoint_out)
+            return (gv, (lv1_adjoint_out,))
+
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+                lv: R.Tensor((3, 3), "float32") = R.multiply(x_scp, x_scp)
+                lv1: R.Tensor((3, 3), "float32") = R.multiply(lv, x_scp)
+                lv1_ecp: R.Tensor((3, 3), "float32") = 
R.grad.end_checkpoint(lv1)
+                gv: R.Tensor((), "float32") = R.sum(lv1_ecp, axis=None, 
keepdims=False)
+                R.output(gv)
+            return gv
+    # fmt: on
+
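+    # Pass the intermediate var lv1 as require_grads to the Gradient pass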
+    After = relax.transform.Gradient("main", lv1)(bb.get())
+    assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
