This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e7d12be07a [Unity][Training] Support intermediate vars as
require_grads for Gradient pass (#16011)
e7d12be07a is described below
commit e7d12be07aba6e6dedade041d47146df33531fa4
Author: Yixin Dong <[email protected]>
AuthorDate: Mon Nov 6 08:37:29 2023 -0800
[Unity][Training] Support intermediate vars as require_grads for Gradient
pass (#16011)
---
python/tvm/relax/op/_op_gradient.py | 8 +-
python/tvm/relax/op/manipulate.py | 2 +-
src/relax/transform/gradient.cc | 111 +++--
tests/python/relax/test_transform_gradient.py | 58 +++
.../relax/test_transform_gradient_checkpoint.py | 548 ++++++++++++---------
5 files changed, 434 insertions(+), 293 deletions(-)
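Usage sketch of the new behavior (a minimal example adapted from the test added in
tests/python/relax/test_transform_gradient.py below; variable names such as lv1 are
illustrative): intermediate dataflow vars, not only function parameters, can now be
passed in require_grads of the Gradient pass.

    from tvm import relax
    from tvm.script.parser import relax as R

    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3, 3), "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, y]):
        with bb.dataflow():
            lv0 = bb.emit(x * x)
            lv1 = bb.emit(lv0 * y)  # intermediate var
            gv0 = bb.emit_output(relax.op.sum(lv1))
        bb.emit_func_output(gv0)

    # Differentiate w.r.t. a parameter (x) and an intermediate var (lv1);
    # the adjoint of lv1 is returned alongside that of x in main_adjoint.
    After = relax.transform.Gradient("main", [x, lv1])(bb.get())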
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index 2873c70ba7..1b0ebfd5e4 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -256,8 +256,8 @@ def maximum_grad(
y = orig_call.args[1]
zero = relax.const(0, _get_dtype(x))
return [
- where(less(x, y), zero, output_grad),
- where(greater_equal(x, y), zero, output_grad),
+ _fit_shape(ctx, where(less(x, y), zero, output_grad), x),
+ _fit_shape(ctx, where(greater_equal(x, y), zero, output_grad), y),
]
@@ -280,8 +280,8 @@ def minimum_grad(
y = orig_call.args[1]
zero = relax.const(0, _get_dtype(x))
return [
- where(greater_equal(x, y), zero, output_grad),
- where(less(x, y), zero, output_grad),
+ _fit_shape(ctx, where(greater_equal(x, y), zero, output_grad), x),
+ _fit_shape(ctx, where(less(x, y), zero, output_grad), y),
]
diff --git a/python/tvm/relax/op/manipulate.py
b/python/tvm/relax/op/manipulate.py
index 3dfa371e42..9bd99020e9 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -170,7 +170,7 @@ def permute_dims(x: Expr, axes: Optional[List[int]] = None)
-> Expr:
The input data to the operator.
axes : Optional[List[int]]
- The target axes order, reverse order if not specified.
+ The target axes order. If not specified, permute_dims will reverse the
order of all axes.
Returns
-------
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
index 8560db6a04..2bf858c98d 100644
--- a/src/relax/transform/gradient.cc
+++ b/src/relax/transform/gradient.cc
@@ -98,18 +98,22 @@ class CheckpointCollector : private ExprMutator {
* and the end_checkpoint bindings.
*
* \param func The original function
- * \return The function with all start_checkpoint and end_checkpoint
bindings removed, and a
- * VarIdSet containing all checkpointed vars.
+ * \return The function with all start_checkpoint and end_checkpoint
bindings removed.
*/
- static std::pair<Function, VarIdSet> Collect(const Function& func) {
+ Function Transform(const Function& func) {
auto collector = CheckpointCollector();
- return std::make_pair(Downcast<Function>(collector.VisitExpr(func)),
collector.checkpoints_);
+ return Downcast<Function>(this->VisitExpr(func));
}
+ // checkpointed vars
+ VarIdSet checkpoints;
+ // mapping from vars that are wrapped in start_checkpoint or end_checkpoint
to the original vars
+ std::unordered_map<Id, Var, ObjectPtrHash, ObjectPtrEqual> var_mapping;
+
private:
Expr VisitExpr_(const FunctionNode* func) final {
for (auto var : func->params) {
- checkpoints_.insert(var->vid);
+ checkpoints.insert(var->vid);
}
return ExprMutator::VisitExpr_(func);
@@ -119,7 +123,7 @@ class CheckpointCollector : private ExprMutator {
static const auto s_cp = Op::Get("relax.grad.start_checkpoint");
static const auto e_cp = Op::Get("relax.grad.end_checkpoint");
- // If every variable that the variable of binding relys on is either
+ // If every variable that the variable of binding relies on is either
// 1) the output of end_checkpoint; 2) checkpointed
// then the variable of binding will be checkpointed
auto var_binding = binding.as<VarBindingNode>();
@@ -131,12 +135,12 @@ class CheckpointCollector : private ExprMutator {
PostOrderVisit(var_binding->value, [this,
&all_inner_var_checkpointed](const Expr& expr) {
if (auto var = expr.as<VarNode>()) {
all_inner_var_checkpointed &=
- (checkpoints_.count(var->vid) != 0 || e_vars_.count(var->vid) !=
0);
+ (checkpoints.count(var->vid) != 0 || e_vars_.count(var->vid) !=
0);
}
});
if (all_inner_var_checkpointed) {
- checkpoints_.insert(var_binding->var->vid);
+ checkpoints.insert(var_binding->var->vid);
}
}
@@ -162,10 +166,11 @@ class CheckpointCollector : private ExprMutator {
} else {
this->var_remap_[binding->var->vid] = orig_var;
}
+ var_mapping[binding->var->vid] = orig_var;
if (value->op == s_cp) {
// mark the original var to be checkpointed
- checkpoints_.insert(orig_var->vid);
+ checkpoints.insert(orig_var->vid);
} else if (value->op == e_cp) {
e_vars_.insert(binding->var->vid);
}
@@ -174,7 +179,7 @@ class CheckpointCollector : private ExprMutator {
}
}
- VarIdSet checkpoints_;
+ // vars that are the output of end_checkpoint
VarIdSet e_vars_;
};
@@ -233,6 +238,9 @@ class CheckpointGenerator : private ExprMutator {
// Visit the use-site of a defined Var
Expr VisitExpr_(const VarNode* op) final { return VisitVar(GetRef<Var>(op));
}
+ // Visit the use-site of a defined DataflowVar
+ Expr VisitExpr_(const DataflowVarNode* op) final { return
VisitVar(GetRef<Var>(op)); }
+
Expr VisitVar(const Var& var) {
auto it = checkpoint_map_.find(var);
if (it != checkpoint_map_.end()) {
@@ -286,9 +294,10 @@ class BackwardBindingGenerator : private ExprVisitor {
static Expr Generate(const BlockBuilder& builder, const DataflowBlock&
forward_block,
const Array<Var>& require_grads, const Var& target_var,
const Array<Var>& orig_params, const Expr&
orig_return_value,
- const VarIdSet& checkpoints) {
- CheckpointGenerator checkpoint_generator(builder, orig_params,
forward_block, checkpoints);
- BackwardBindingGenerator generator(builder, checkpoint_generator);
+ const CheckpointCollector& cp_collector) {
+ CheckpointGenerator checkpoint_generator(builder, orig_params,
forward_block,
+ cp_collector.checkpoints);
+ BackwardBindingGenerator generator(builder, cp_collector,
checkpoint_generator);
// Initialize the adjoint of target_var as ones op. We have already
checked the target.
auto* target_sinfo = GetStructInfoAs<TensorStructInfoNode>(target_var);
@@ -304,8 +313,11 @@ class BackwardBindingGenerator : private ExprVisitor {
private:
explicit BackwardBindingGenerator(const BlockBuilder& builder,
+ const CheckpointCollector& cp_collector,
const CheckpointGenerator&
checkpoint_generator)
- : builder_(builder), checkpoint_generator_(checkpoint_generator) {}
+ : builder_(builder),
+ cp_collector_(cp_collector),
+ checkpoint_generator_(checkpoint_generator) {}
void VisitBinding(const Binding& binding) final {
// TODO(chaofan, yixin): support other types of bindings
@@ -469,6 +481,11 @@ class BackwardBindingGenerator : private ExprVisitor {
Array<Expr> out_adjoints;
for (Var var : require_grads) {
+ // var might be wrapped in start_checkpoint or end_checkpoint, so we
should find the original
+ // var first
+ if (cp_collector_.var_mapping.count(var->vid)) {
+ var = cp_collector_.var_mapping[var->vid];
+ }
// If the var doesn't have an adjoint var, it does not contribute to the
// target. So its adjoint is zeros
auto it = adjoint_var_map_.find(var);
@@ -575,6 +592,8 @@ class BackwardBindingGenerator : private ExprVisitor {
BlockBuilder builder_;
// Forward Var to its adjoint Var
Map<Var, Var> adjoint_var_map_;
+ // information collected by CheckpointCollector
+ CheckpointCollector cp_collector_;
// The generator for checkpoint bindings
CheckpointGenerator checkpoint_generator_;
};
@@ -586,48 +605,42 @@ class GradientMutator : private ExprMutator {
// Step 1. Copy function
auto* old_func = mod->Lookup(func_name).as<FunctionNode>();
CHECK(old_func) << func_name << " is not a Relax Function";
- auto new_func = CopyWithNewVars(GetRef<Function>(old_func));
+ auto copier = FunctionCopier();
+ auto new_func = copier.Copy(GetRef<Function>(old_func));
// Step 2. Handle the checkpoints and eliminate start_checkpoint and
end_checkpoint ops
- auto checkpoint_collected = CheckpointCollector::Collect(new_func);
- new_func = checkpoint_collected.first;
- auto checkpoints = checkpoint_collected.second;
-
- // Step 3. Collect call_tir_with_grad information
- auto tir_grad_collected = CallTIRWithGradEliminator::Transform(new_func);
+ auto cp_collector = CheckpointCollector();
+ new_func = cp_collector.Transform(new_func);
- // Step 4. Handle require_grads
+ // Step 3. Handle require_grads
// When require_grads is not specified, it would be set to all params of
the function
- if (require_grads) {
- CheckRequireGrads(require_grads.value(), old_func->params, func_name);
+ if (!require_grads) {
+ require_grads = new_func->params;
+ } else {
+ require_grads = CheckAndMapRequireGrads(require_grads.value(),
copier.GetVarMap(), func_name);
}
- // then map the parameter list into new params
- auto require_grads_value =
require_grads.value_or(old_func->params).Map([&](const Var& v) {
- return new_func->params[std::find(old_func->params.begin(),
old_func->params.end(), v) -
- old_func->params.begin()];
- });
- // Step 5. Generate the adjoint function, use RemoveAllUnused to simplify
it, and then return
+ // Step 4. Generate the adjoint function, use RemoveAllUnused to simplify
it, and then return
// the IRModule with the adjoint function
- return GradientMutator(mod, require_grads_value, target_index, checkpoints)
+ return GradientMutator(mod, require_grads.value(), target_index,
cp_collector)
.AddAdjointFunction(new_func, func_name, true);
}
private:
GradientMutator(const IRModule& module, const Array<Var>& require_grads, int
target_index,
- const VarIdSet& checkpoints)
+ const CheckpointCollector& cp_collector)
: ExprMutator(module),
require_grads_(require_grads),
- checkpoints_(checkpoints),
+ cp_collector_(cp_collector),
target_index_(target_index) {}
// Add the adjoint function of func to the IRModule using BlockBuilder
IRModule AddAdjointFunction(const Function& func, const String& func_name,
bool remove_all_unused = true) {
- // Step 5.1 forward -> forward + backward
+ // Step 4.1 forward -> forward + backward
auto new_func = Downcast<Function>(VisitExpr(func));
- // Step 5.2 Convert call_tir_with_grad nodes into call_tir nodes
+ // Step 4.2 Convert call_tir_with_grad nodes into call_tir nodes
// because call_tir_with_grad nodes are not actually implemented
new_func = CallTIRWithGradEliminator::Transform(new_func);
@@ -635,12 +648,12 @@ class GradientMutator : private ExprMutator {
new_func = Downcast<Function>(RemoveAllUnused(new_func));
}
- // Step 5.3 mark the transformed function as public
+ // Step 4.3 mark the transformed function as public
// because the original function may be public, and have gsymbol attribute
as func_name
auto new_func_name = func_name + "_adjoint";
auto new_func_with_gsymbol = WithAttr(new_func, tvm::attr::kGlobalSymbol,
new_func_name);
- // Step 5.4 Add the transformed function to IRModule
+ // Step 4.4 Add the transformed function to IRModule
builder_->AddFunction(new_func_with_gsymbol, new_func_name);
return builder_->GetContextIRModule();
}
@@ -679,7 +692,7 @@ class GradientMutator : private ExprMutator {
// generate backward bindings and the return value
return_expr_ = BackwardBindingGenerator::Generate(builder_,
GetRef<DataflowBlock>(block),
require_grads_,
target_var_, orig_params_,
- orig_return_expr_,
checkpoints_);
+ orig_return_expr_,
cp_collector_);
return builder_->EndBlock();
}
@@ -700,7 +713,8 @@ class GradientMutator : private ExprMutator {
target_var_ = GetRef<Var>(var);
} else if (auto* tuple = e.as<TupleNode>()) {
CHECK(target_index >= 0 && target_index <
static_cast<int>(tuple->fields.size()))
- << "target_index should be in the range of the number of return
values of the function. "
+ << "target_index should be in the range of the number of return
values of the "
+ "function. "
"But the specified target_index is "
<< target_index << ", while the number of return values is " <<
tuple->fields.size();
auto* var = tuple->fields[target_index].as<VarNode>();
@@ -721,30 +735,33 @@ class GradientMutator : private ExprMutator {
// Check every Var in require_grads:
// 1. there should be no duplicate var
- // 2. every var should be a parameter of the function
+ // 2. every var should be a parameter or an intermediate var in the function
// 3. the type of the input var should be Tensor of floating point dtype, or
Tuple of that
- static void CheckRequireGrads(const Array<Var>& require_grads, const
Array<Var>& func_params,
- const String& func_name) {
+ static Array<Var> CheckAndMapRequireGrads(const Array<Var>& require_grads,
+ const Map<Var, Var>& var_map,
const String& func_name) {
VarIdSet var_set;
+ Array<Var> mapped_vars;
for (const auto& var : require_grads) {
- CHECK(std::find(func_params.begin(), func_params.end(), var) !=
func_params.end())
- << "There is no Var named " << var->name_hint() << " in the
parameters of the function "
- << func_name;
+ auto it = var_map.find(var);
+ CHECK(it != var_map.end()) << "There is no Var named " <<
var->name_hint()
+ << " in the function " << func_name;
CHECK_EQ(var_set.count(var->vid), 0)
<< "Var " << var->name_hint() << " appears more than once";
var_set.emplace(var->vid);
+ mapped_vars.push_back((*it).second);
CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo))
<< "Only Tensors of floating point dtype or Tuples of float "
"Tensors can require gradients, but the StructInfo of Var "
<< var->name_hint() << " is " << GetStructInfo(var);
}
+ return mapped_vars;
}
// differentiation sources
Array<Var> require_grads_;
- // checkpoint
- VarIdSet checkpoints_;
+ // information collected by CheckpointCollector
+ CheckpointCollector cp_collector_;
// the differentiation target
int target_index_;
Var target_var_;
diff --git a/tests/python/relax/test_transform_gradient.py
b/tests/python/relax/test_transform_gradient.py
index b96932f8c5..c4e2f9d526 100644
--- a/tests/python/relax/test_transform_gradient.py
+++ b/tests/python/relax/test_transform_gradient.py
@@ -309,6 +309,64 @@ def test_target_index():
assert_structural_equal(After, Expected)
+def test_intermediate_var_require_grads():
+ x = relax.Var("x", R.Tensor((3, 3), "float32"))
+ y = relax.Var("y", R.Tensor((3, 3), "float32"))
+
+ bb = relax.BlockBuilder()
+ with bb.function("main", [x, y]):
+ with bb.dataflow():
+ lv0 = bb.emit(x * x)
+ lv1 = bb.emit(lv0 * y)
+ lv2 = bb.emit(lv1 * y)
+ gv0 = bb.emit_output(relax.op.sum(lv2))
+ bb.emit_func_output(gv0)
+
+ Before = bb.get()
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3,
3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"),
R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32"),
R.Tensor((), dtype="float32"))):
+ with R.dataflow():
+ lv: R.Tensor((3, 3), dtype="float32") = R.multiply(x, x)
+ lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, y)
+ lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, y)
+ gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
+ lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv1_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, y)
+ lv_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, y)
+ x_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv_adjoint, x)
+ lv1_1: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv_adjoint, x)
+ x_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.add(x_adjoint, lv1_1)
+ x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+ lv1_adjoint_out: R.Tensor((3, 3), dtype="float32") =
lv1_adjoint
+ gv_adjoint_out: R.Tensor((), dtype="float32") = gv_adjoint
+ R.output(gv, x_adjoint_out, lv1_adjoint_out, gv_adjoint_out)
+ return (gv, (x_adjoint_out, lv1_adjoint_out, gv_adjoint_out))
+
+ @R.function
+ def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3),
dtype="float32")) -> R.Tensor((), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((3, 3), dtype="float32") = R.multiply(x, x)
+ lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(lv, y)
+ lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, y)
+ gv: R.Tensor((), dtype="float32") = R.sum(lv2, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+ # fmt: on
+
+ After = relax.transform.Gradient("main", [x, lv1, gv0])(Before)
+ assert_structural_equal(After, Expected)
+
+ # z does not occur in the function
+ z = relax.Var("z", R.Tensor((3, 3), "float32"))
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main", [x, lv1, z])(Before)
+
+
def test_tuple():
# fmt: off
@I.ir_module
diff --git a/tests/python/relax/test_transform_gradient_checkpoint.py
b/tests/python/relax/test_transform_gradient_checkpoint.py
index 3e94125b77..9de62d341c 100644
--- a/tests/python/relax/test_transform_gradient_checkpoint.py
+++ b/tests/python/relax/test_transform_gradient_checkpoint.py
@@ -16,12 +16,13 @@
# under the License.
"""Unit tests for gradient with checkpointing."""
import tvm
-from tvm.relax.block_builder import BlockBuilder
-from tvm.relax.testing.nn import checkpoint, emit_checkpoint,
emit_checkpoint_sequential
import tvm.testing
+
from tvm import relax
from tvm.ir.base import assert_structural_equal
-from tvm.script.parser import relax as R, ir as I
+from tvm.relax.block_builder import BlockBuilder
+from tvm.relax.testing import nn
+from tvm.script.parser import ir as I, relax as R
def test_sequential():
@@ -47,51 +48,51 @@ def test_sequential():
@I.ir_module
class Expected:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) ->
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3),
dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3,
"float32"))
- lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1,
R.const(3, "float32"))
- lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2,
R.const(3, "float32"))
- lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3,
R.const(3, "float32"))
- gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None,
keepdims=False)
- gv_1: R.Tensor((), dtype="float32") = gv
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- gv_adjoint1: R.Tensor((), dtype="float32") = gv_adjoint
- lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2,
R.const(3, "float32"))
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint1, R.shape([3, 3]))
- lv: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, R.const(3, "float32"))
- lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp,
lv1_1)
- lv3_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv, lv2_1)
- lv6: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint, R.const(3, "float32"))
- lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7)
- lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv6, lv8)
- lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x,
R.const(3, "float32"))
- lv12: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, R.const(3, "float32"))
- lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13)
- lv1_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv12, lv14)
- lv18: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, R.const(3, "float32"))
- lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19)
- x_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv18, lv20)
- x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+ lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3,
"float32"))
+ lv3: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3,
"float32"))
+ lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3,
"float32"))
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_1: R.Tensor((), "float32") = gv
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ gv_adjoint1: R.Tensor((), "float32") = gv_adjoint
+ lv3_cp: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3,
"float32"))
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint1, R.shape([3, 3]))
+ lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint,
R.const(3, "float32"))
+ lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv2_1: R.Tensor((3, 3), "float32") = R.power(lv3_cp, lv1_1)
+ lv3_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv,
lv2_1)
+ lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3_adjoint,
R.const(3, "float32"))
+ lv7: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv8: R.Tensor((3, 3), "float32") = R.power(lv2, lv7)
+ lv2_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+ lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv12: R.Tensor((3, 3), "float32") = R.multiply(lv2_adjoint,
R.const(3, "float32"))
+ lv13: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv14: R.Tensor((3, 3), "float32") = R.power(lv1_cp, lv13)
+ lv1_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv12,
lv14)
+ lv18: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint,
R.const(3, "float32"))
+ lv19: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv20: R.Tensor((3, 3), "float32") = R.power(x, lv19)
+ x_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv18, lv20)
+ x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
R.output(gv_1, x_adjoint_out)
return (gv_1, (x_adjoint_out,))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
- x_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(x)
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp,
R.const(3, "float32"))
- lv1_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv1)
- lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1_ecp,
R.const(3, "float32"))
- lv2_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv2)
- lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2_scp,
R.const(3, "float32"))
- lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3,
R.const(3, "float32"))
- gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None,
keepdims=False)
- gv_ecp: R.Tensor((), dtype="float32") =
R.grad.end_checkpoint(gv)
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3,
"float32"))
+ lv1_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv1)
+ lv2: R.Tensor((3, 3), "float32") = R.power(lv1_ecp, R.const(3,
"float32"))
+ lv2_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv2)
+ lv3: R.Tensor((3, 3), "float32") = R.power(lv2_scp, R.const(3,
"float32"))
+ lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3,
"float32"))
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_ecp: R.Tensor((), "float32") = R.grad.end_checkpoint(gv)
R.output(gv_ecp)
return gv_ecp
# fmt: on
@@ -123,49 +124,49 @@ def test_sequential_consecutive():
@I.ir_module
class Expected:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) ->
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3),
dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3,
"float32"))
- lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1,
R.const(3, "float32"))
- lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2,
R.const(3, "float32"))
- lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3,
R.const(3, "float32"))
- gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None,
keepdims=False)
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- lv3_cp: R.Tensor((3, 3), dtype="float32") = R.power(lv2,
R.const(3, "float32"))
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
- lv: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, R.const(3, "float32"))
- lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv3_cp,
lv1_1)
- lv3_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv, lv2_1)
- lv6: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint, R.const(3, "float32"))
- lv7: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv8: R.Tensor((3, 3), dtype="float32") = R.power(lv2, lv7)
- lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv6, lv8)
- lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x,
R.const(3, "float32"))
- lv12: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, R.const(3, "float32"))
- lv13: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv14: R.Tensor((3, 3), dtype="float32") = R.power(lv1_cp, lv13)
- lv1_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv12, lv14)
- lv18: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, R.const(3, "float32"))
- lv19: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv20: R.Tensor((3, 3), dtype="float32") = R.power(x, lv19)
- x_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv18, lv20)
- x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
+ lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3,
"float32"))
+ lv3: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3,
"float32"))
+ lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3,
"float32"))
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv3_cp: R.Tensor((3, 3), "float32") = R.power(lv2, R.const(3,
"float32"))
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint,
R.const(3, "float32"))
+ lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv2_1: R.Tensor((3, 3), "float32") = R.power(lv3_cp, lv1_1)
+ lv3_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv,
lv2_1)
+ lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3_adjoint,
R.const(3, "float32"))
+ lv7: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv8: R.Tensor((3, 3), "float32") = R.power(lv2, lv7)
+ lv2_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+ lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv12: R.Tensor((3, 3), "float32") = R.multiply(lv2_adjoint,
R.const(3, "float32"))
+ lv13: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv14: R.Tensor((3, 3), "float32") = R.power(lv1_cp, lv13)
+ lv1_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv12,
lv14)
+ lv18: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint,
R.const(3, "float32"))
+ lv19: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv20: R.Tensor((3, 3), "float32") = R.power(x, lv19)
+ x_adjoint: R.Tensor((3, 3), "float32") = R.multiply(lv18, lv20)
+ x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
R.output(gv, x_adjoint_out)
return (gv, (x_adjoint_out,))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
- x_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(x)
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp,
R.const(3, "float32"))
- lv2: R.Tensor((3, 3), dtype="float32") = R.power(lv1,
R.const(3, "float32"))
- lv2_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv2)
- lv2_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv2_ecp)
- lv3: R.Tensor((3, 3), dtype="float32") = R.power(lv2_scp,
R.const(3, "float32"))
- lv4: R.Tensor((3, 3), dtype="float32") = R.power(lv3,
R.const(3, "float32"))
- lv4_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv4)
- gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None,
keepdims=False)
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3,
"float32"))
+ lv2: R.Tensor((3, 3), "float32") = R.power(lv1, R.const(3,
"float32"))
+ lv2_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv2)
+ lv2_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv2_ecp)
+ lv3: R.Tensor((3, 3), "float32") = R.power(lv2_scp, R.const(3,
"float32"))
+ lv4: R.Tensor((3, 3), "float32") = R.power(lv3, R.const(3,
"float32"))
+ lv4_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv4)
+ gv: R.Tensor((), "float32") = R.sum(lv4_ecp, axis=None,
keepdims=False)
R.output(gv)
return gv
@@ -196,49 +197,49 @@ def test_tuple():
@I.ir_module
class Expected:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) ->
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3),
dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x, R.const(3,
"float32"))
- lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,
3), dtype="float32")) = x, lv1
- lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,
3), dtype="float32")) = lv2
- lv4: R.Tensor((3, 3), dtype="float32") = lv3[0]
- lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4,
R.const(3, "float32"))
- gv: R.Tensor((), dtype="float32") = R.sum(lv4_1, axis=None,
keepdims=False)
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- lv1_cp: R.Tensor((3, 3), dtype="float32") = R.power(x,
R.const(3, "float32"))
- lv2_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"),
R.Tensor((3, 3), dtype="float32")) = x, lv1_cp
- lv3_cp: R.Tuple(R.Tensor((3, 3), dtype="float32"),
R.Tensor((3, 3), dtype="float32")) = lv2_cp
- lv4_cp: R.Tensor((3, 3), dtype="float32") = lv3_cp[0]
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
- lv: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, R.const(3, "float32"))
- lv1_1: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv2_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4_cp,
lv1_1)
- lv4_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv, lv2_1)
- lv6: R.Tensor((3, 3), dtype="float32") = R.zeros(R.shape([3,
3]), dtype="float32")
- lv3_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"),
R.Tensor((3, 3), dtype="float32")) = lv4_adjoint1, lv6
- lv2_adjoint: R.Tuple(R.Tensor((3, 3), dtype="float32"),
R.Tensor((3, 3), dtype="float32")) = lv3_adjoint
- x_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[0]
- lv1_adjoint: R.Tensor((3, 3), dtype="float32") = lv2_adjoint[1]
- lv7: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, R.const(3, "float32"))
- lv8: R.Tensor((), dtype="float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
- lv9: R.Tensor((3, 3), dtype="float32") = R.power(x, lv8)
- lv12: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7, lv9)
- x_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.add(x_adjoint, lv12)
- x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+ lv1: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = x, lv1
+ lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv2
+ lv4: R.Tensor((3, 3), "float32") = lv3[0]
+ lv4_1: R.Tensor((3, 3), "float32") = R.power(lv4, R.const(3,
"float32"))
+ gv: R.Tensor((), "float32") = R.sum(lv4_1, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv1_cp: R.Tensor((3, 3), "float32") = R.power(x, R.const(3,
"float32"))
+ lv2_cp: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = x, lv1_cp
+ lv3_cp: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv2_cp
+ lv4_cp: R.Tensor((3, 3), "float32") = lv3_cp[0]
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint,
R.const(3, "float32"))
+ lv1_1: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv2_1: R.Tensor((3, 3), "float32") = R.power(lv4_cp, lv1_1)
+ lv4_adjoint1: R.Tensor((3, 3), "float32") = R.multiply(lv,
lv2_1)
+ lv6: R.Tensor((3, 3), "float32") = R.zeros(R.shape([3, 3]),
"float32")
+ lv3_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = lv4_adjoint1, lv6
+ lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = lv3_adjoint
+ x_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint[0]
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint[1]
+ lv7: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint,
R.const(3, "float32"))
+ lv8: R.Tensor((), "float32") = R.subtract(R.const(3,
"float32"), R.const(1, "float32"))
+ lv9: R.Tensor((3, 3), "float32") = R.power(x, lv8)
+ lv12: R.Tensor((3, 3), "float32") = R.multiply(lv7, lv9)
+ x_adjoint1: R.Tensor((3, 3), "float32") = R.add(x_adjoint,
lv12)
+ x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint1
R.output(gv, x_adjoint_out)
return (gv, (x_adjoint_out,))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
- x_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(x)
- lv1: R.Tensor((3, 3), dtype="float32") = R.power(x_scp,
R.const(3, "float32"))
- lv2: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,
3), dtype="float32")) = x, lv1
- lv3: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3,
3), dtype="float32")) = lv2
- lv4: R.Tensor((3, 3), dtype="float32") = lv3[0]
- lv4_1: R.Tensor((3, 3), dtype="float32") = R.power(lv4,
R.const(3, "float32"))
- lv4_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv4_1)
- gv: R.Tensor((), dtype="float32") = R.sum(lv4_ecp, axis=None,
keepdims=False)
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv1: R.Tensor((3, 3), "float32") = R.power(x_scp, R.const(3,
"float32"))
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = x, lv1
+ lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv2
+ lv4: R.Tensor((3, 3), "float32") = lv3[0]
+ lv4_1: R.Tensor((3, 3), "float32") = R.power(lv4, R.const(3,
"float32"))
+ lv4_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv4_1)
+ gv: R.Tensor((), "float32") = R.sum(lv4_ecp, axis=None,
keepdims=False)
R.output(gv)
return gv
# fmt: on
@@ -272,35 +273,35 @@ def test_tree():
@I.ir_module
class Expected1:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3,
3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3),
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((),
dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32"), R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3),
dtype="float32"), R.Tensor((3, 3), dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v:
R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"),
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y)
- lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
- lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
- lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3)
- gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None,
keepdims=False)
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
- lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
- lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
- lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, lv3_cp)
- lv3_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, lv2_cp)
- u_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint, v)
- v_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint, u)
- lv1_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, z)
- z_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, lv1)
- x_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, y)
- y_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, x)
- x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint
- y_adjoint_out: R.Tensor((3, 3), dtype="float32") = y_adjoint
- z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint
- u_adjoint_out: R.Tensor((3, 3), dtype="float32") = u_adjoint
- v_adjoint_out: R.Tensor((3, 3), dtype="float32") = v_adjoint
+ lv1: R.Tensor((3, 3), "float32") = R.multiply(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+ lv3: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+ lv4: R.Tensor((3, 3), "float32") = R.multiply(lv2, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv2_cp: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+ lv3_cp: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv4_adjoint, lv3_cp)
+ lv3_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv4_adjoint, lv2_cp)
+ u_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv3_adjoint, v)
+ v_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv3_adjoint, u)
+ lv1_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv2_adjoint, z)
+ z_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv2_adjoint, lv1)
+ x_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv1_adjoint, y)
+ y_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv1_adjoint, x)
+ x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint
+ y_adjoint_out: R.Tensor((3, 3), "float32") = y_adjoint
+ z_adjoint_out: R.Tensor((3, 3), "float32") = z_adjoint
+ u_adjoint_out: R.Tensor((3, 3), "float32") = u_adjoint
+ v_adjoint_out: R.Tensor((3, 3), "float32") = v_adjoint
R.output(gv, x_adjoint_out, y_adjoint_out, z_adjoint_out,
u_adjoint_out, v_adjoint_out)
return (gv, (x_adjoint_out, y_adjoint_out, z_adjoint_out,
u_adjoint_out, v_adjoint_out))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3),
dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3),
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v:
R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
lv1 = x * y
lv1_scp = R.grad.start_checkpoint(lv1)
@@ -324,24 +325,24 @@ def test_tree():
@I.ir_module
class Expected2:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3,
3), dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3),
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tuple(R.Tensor((),
dtype="float32"), R.Tuple(R.Tensor((3, 3), dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v:
R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"),
R.Tuple(R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x, y)
- lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1, z)
- lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
- lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv2, lv3)
- gv: R.Tensor((), dtype="float32") = R.sum(lv4, axis=None,
keepdims=False)
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
- lv3_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(u, v)
- lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, lv3_cp)
- z_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, lv1)
- z_adjoint_out: R.Tensor((3, 3), dtype="float32") = z_adjoint
+ lv1: R.Tensor((3, 3), "float32") = R.multiply(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, z)
+ lv3: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+ lv4: R.Tensor((3, 3), "float32") = R.multiply(lv2, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv3_cp: R.Tensor((3, 3), "float32") = R.multiply(u, v)
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv4_adjoint, lv3_cp)
+ z_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv2_adjoint, lv1)
+ z_adjoint_out: R.Tensor((3, 3), "float32") = z_adjoint
R.output(gv, z_adjoint_out)
return (gv, (z_adjoint_out,))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3),
dtype="float32"), z: R.Tensor((3, 3), dtype="float32"), u: R.Tensor((3, 3),
dtype="float32"), v: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32"), v:
R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
lv1 = x * y
lv1_scp = R.grad.start_checkpoint(lv1)
@@ -387,54 +388,54 @@ def test_dag():
lv12 = R.multiply(lv11, R.const(2, "float32"))
lv13 = R.grad.end_checkpoint(lv12)
lv14 = R.multiply(lv9, lv13)
- gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None,
keepdims=False)
+ gv: R.Tensor((), "float32") = R.sum(lv14, axis=None,
keepdims=False)
R.output(gv)
return gv
@I.ir_module
class Expected:
@R.function
- def main_adjoint(x: R.Tensor((3, 3), dtype="float32")) ->
R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 3),
dtype="float32"))):
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
with R.dataflow():
- lv1: R.Tensor((3, 3), dtype="float32") = R.multiply(x,
R.const(2, "float32"))
- lv2: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1,
R.const(2, "float32"))
- lv3: R.Tensor((3, 3), dtype="float32") = R.multiply(x, lv2)
- lv4: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3,
R.const(2, "float32"))
- lv5: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4,
R.const(2, "float32"))
- lv6: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3, lv5)
- lv7: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6,
R.const(2, "float32"))
- lv8: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7,
R.const(2, "float32"))
- lv9: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6, lv8)
- gv: R.Tensor((), dtype="float32") = R.sum(lv9, axis=None,
keepdims=False)
- gv_adjoint: R.Tensor((), dtype="float32") =
R.ones(R.shape([]), dtype="float32")
- lv9_adjoint: R.Tensor((3, 3), dtype="float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
- lv7_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv6,
R.const(2, "float32"))
- lv8_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv7_cp,
R.const(2, "float32"))
- lv6_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv9_adjoint, lv8_cp)
- lv8_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv9_adjoint, lv6)
- lv7_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv8_adjoint, R.const(2, "float32"))
- lv1_1: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv7_adjoint, R.const(2, "float32"))
- lv6_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.add(lv6_adjoint, lv1_1)
- lv4_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv3,
R.const(2, "float32"))
- lv5_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv4_cp,
R.const(2, "float32"))
- lv3_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv6_adjoint1, lv5_cp)
- lv5_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv6_adjoint1, lv3)
- lv4_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv5_adjoint, R.const(2, "float32"))
- lv4_1: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv4_adjoint, R.const(2, "float32"))
- lv3_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.add(lv3_adjoint, lv4_1)
- lv1_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(x,
R.const(2, "float32"))
- lv2_cp: R.Tensor((3, 3), dtype="float32") = R.multiply(lv1_cp,
R.const(2, "float32"))
- x_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint1, lv2_cp)
- lv2_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv3_adjoint1, x)
- lv1_adjoint: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv2_adjoint, R.const(2, "float32"))
- lv7_1: R.Tensor((3, 3), dtype="float32") =
R.multiply(lv1_adjoint, R.const(2, "float32"))
- x_adjoint1: R.Tensor((3, 3), dtype="float32") =
R.add(x_adjoint, lv7_1)
- x_adjoint_out: R.Tensor((3, 3), dtype="float32") = x_adjoint1
+ lv1: R.Tensor((3, 3), "float32") = R.multiply(x, R.const(2,
"float32"))
+ lv2: R.Tensor((3, 3), "float32") = R.multiply(lv1, R.const(2,
"float32"))
+ lv3: R.Tensor((3, 3), "float32") = R.multiply(x, lv2)
+ lv4: R.Tensor((3, 3), "float32") = R.multiply(lv3, R.const(2,
"float32"))
+ lv5: R.Tensor((3, 3), "float32") = R.multiply(lv4, R.const(2,
"float32"))
+ lv6: R.Tensor((3, 3), "float32") = R.multiply(lv3, lv5)
+ lv7: R.Tensor((3, 3), "float32") = R.multiply(lv6, R.const(2,
"float32"))
+ lv8: R.Tensor((3, 3), "float32") = R.multiply(lv7, R.const(2,
"float32"))
+ lv9: R.Tensor((3, 3), "float32") = R.multiply(lv6, lv8)
+ gv: R.Tensor((), "float32") = R.sum(lv9, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv9_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv7_cp: R.Tensor((3, 3), "float32") = R.multiply(lv6,
R.const(2, "float32"))
+ lv8_cp: R.Tensor((3, 3), "float32") = R.multiply(lv7_cp,
R.const(2, "float32"))
+ lv6_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv9_adjoint, lv8_cp)
+ lv8_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv9_adjoint, lv6)
+ lv7_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv8_adjoint, R.const(2, "float32"))
+ lv1_1: R.Tensor((3, 3), "float32") = R.multiply(lv7_adjoint,
R.const(2, "float32"))
+ lv6_adjoint1: R.Tensor((3, 3), "float32") = R.add(lv6_adjoint,
lv1_1)
+ lv4_cp: R.Tensor((3, 3), "float32") = R.multiply(lv3,
R.const(2, "float32"))
+ lv5_cp: R.Tensor((3, 3), "float32") = R.multiply(lv4_cp,
R.const(2, "float32"))
+ lv3_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv6_adjoint1, lv5_cp)
+ lv5_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv6_adjoint1, lv3)
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv5_adjoint, R.const(2, "float32"))
+ lv4_1: R.Tensor((3, 3), "float32") = R.multiply(lv4_adjoint,
R.const(2, "float32"))
+ lv3_adjoint1: R.Tensor((3, 3), "float32") = R.add(lv3_adjoint,
lv4_1)
+ lv1_cp: R.Tensor((3, 3), "float32") = R.multiply(x, R.const(2,
"float32"))
+ lv2_cp: R.Tensor((3, 3), "float32") = R.multiply(lv1_cp,
R.const(2, "float32"))
+ x_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv3_adjoint1, lv2_cp)
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv3_adjoint1, x)
+ lv1_adjoint: R.Tensor((3, 3), "float32") =
R.multiply(lv2_adjoint, R.const(2, "float32"))
+ lv7_1: R.Tensor((3, 3), "float32") = R.multiply(lv1_adjoint,
R.const(2, "float32"))
+ x_adjoint1: R.Tensor((3, 3), "float32") = R.add(x_adjoint,
lv7_1)
+ x_adjoint_out: R.Tensor((3, 3), "float32") = x_adjoint1
R.output(gv, x_adjoint_out)
return (gv, (x_adjoint_out,))
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
lv = R.grad.start_checkpoint(x)
lv1 = R.multiply(lv, R.const(2, "float32"))
@@ -451,7 +452,7 @@ def test_dag():
lv12 = R.multiply(lv11, R.const(2, "float32"))
lv13 = R.grad.end_checkpoint(lv12)
lv14 = R.multiply(lv9, lv13)
- gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None,
keepdims=False)
+ gv: R.Tensor((), "float32") = R.sum(lv14, axis=None,
keepdims=False)
R.output(gv)
return gv
# fmt: on
@@ -474,9 +475,9 @@ def test_checkpoint_api():
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
with bb.function("main", [x]):
with bb.dataflow():
- lv1 = bb.emit(checkpoint(func1, x))
+ lv1 = bb.emit(nn.checkpoint(func1, x))
lv2 = bb.emit(relax.op.power(lv1, relax.const(3, "float32")))
- lv3 = bb.emit_output(checkpoint(func2, lv2))
+ lv3 = bb.emit_output(nn.checkpoint(func2, lv2))
bb.emit_func_output(lv3)
# fmt: off
@@ -516,7 +517,7 @@ def test_checkpoint_tree():
with bb.function("main", [x, y, z, u, v]):
with bb.dataflow():
lv1 = bb.emit(x * y)
- cp = checkpoint(func, lv1, z, u, v)
+ cp = nn.checkpoint(func, lv1, z, u, v)
lv2 = bb.emit(cp[0])
lv3 = bb.emit(cp[1])
lv4 = bb.emit(lv2 * lv3)
@@ -559,11 +560,11 @@ def test_checkpoint_dag():
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
with bb.function("main", [x]):
with bb.dataflow():
- lv1 = bb.emit(checkpoint(func, x))
+ lv1 = bb.emit(nn.checkpoint(func, x))
lv2 = bb.emit(x * lv1)
- lv3 = bb.emit(checkpoint(func, lv2))
+ lv3 = bb.emit(nn.checkpoint(func, lv2))
lv4 = bb.emit(lv2 * lv3)
- lv5 = bb.emit(checkpoint(func, lv4))
+ lv5 = bb.emit(nn.checkpoint(func, lv4))
lv6 = bb.emit(lv4 * lv5)
gv = bb.emit_output(relax.op.sum(lv6))
bb.emit_func_output(gv)
@@ -572,7 +573,7 @@ def test_checkpoint_dag():
@I.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
with R.dataflow():
lv = R.grad.start_checkpoint(x)
lv1 = R.multiply(lv, R.const(2, "float32"))
@@ -589,7 +590,7 @@ def test_checkpoint_dag():
lv12 = R.multiply(lv11, R.const(2, "float32"))
lv13 = R.grad.end_checkpoint(lv12)
lv14 = R.multiply(lv9, lv13)
- gv: R.Tensor((), dtype="float32") = R.sum(lv14, axis=None,
keepdims=False)
+ gv: R.Tensor((), "float32") = R.sum(lv14, axis=None,
keepdims=False)
R.output(gv)
return gv
# fmt: on
@@ -605,8 +606,8 @@ def test_checkpoint_sequential():
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
with bb.function("main", [x]):
with bb.dataflow():
- lv1 = emit_checkpoint_sequential([func] * 5, 2, x)
- lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1)
+ lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x)
+ lv2 = nn.emit_checkpoint_sequential([func] * 4, 2, lv1)
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
@@ -614,24 +615,24 @@ def test_checkpoint_sequential():
@I.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((3, 3),
"float32"):
with R.dataflow():
- x_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(x)
- lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp)
- lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv)
- lv1_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv1)
- lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv1_ecp)
- lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp,
lv1_ecp_scp)
- lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2)
- lv3_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv3)
- lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp,
lv3_ecp)
- lv4_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv4)
- lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_scp,
lv4_scp)
- lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5)
- lv6_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv6)
- lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp,
lv6_ecp)
- lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7)
- gv: R.Tensor((3, 3), dtype="float32") = lv8
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv: R.Tensor((3, 3), "float32") = R.add(x_scp, x_scp)
+ lv1: R.Tensor((3, 3), "float32") = R.add(lv, lv)
+ lv1_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv1)
+ lv1_ecp_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv1_ecp)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1_ecp_scp,
lv1_ecp_scp)
+ lv3: R.Tensor((3, 3), "float32") = R.add(lv2, lv2)
+ lv3_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv3)
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv3_ecp, lv3_ecp)
+ lv4_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv4)
+ lv5: R.Tensor((3, 3), "float32") = R.add(lv4_scp, lv4_scp)
+ lv6: R.Tensor((3, 3), "float32") = R.add(lv5, lv5)
+ lv6_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv6)
+ lv7: R.Tensor((3, 3), "float32") = R.add(lv6_ecp, lv6_ecp)
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv7, lv7)
+ gv: R.Tensor((3, 3), "float32") = lv8
R.output(gv)
return gv
# fmt: on
@@ -647,8 +648,8 @@ def test_checkpoint_sequential_checkpoint_last():
x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
with bb.function("main", [x]):
with bb.dataflow():
- lv1 = emit_checkpoint_sequential([func] * 5, 2, x,
checkpoint_last=True)
- lv2 = emit_checkpoint_sequential([func] * 4, 2, lv1,
checkpoint_last=True)
+ lv1 = nn.emit_checkpoint_sequential([func] * 5, 2, x,
checkpoint_last=True)
+ lv2 = nn.emit_checkpoint_sequential([func] * 4, 2, lv1,
checkpoint_last=True)
gv = bb.emit_output(lv2)
bb.emit_func_output(gv)
@@ -656,28 +657,28 @@ def test_checkpoint_sequential_checkpoint_last():
@I.ir_module
class Expected:
@R.function
- def main(x: R.Tensor((3, 3), dtype="float32")) -> R.Tensor((3, 3),
dtype="float32"):
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((3, 3),
"float32"):
with R.dataflow():
- x_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(x)
- lv: R.Tensor((3, 3), dtype="float32") = R.add(x_scp, x_scp)
- lv1: R.Tensor((3, 3), dtype="float32") = R.add(lv, lv)
- lv1_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv1)
- lv1_ecp_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv1_ecp)
- lv2: R.Tensor((3, 3), dtype="float32") = R.add(lv1_ecp_scp,
lv1_ecp_scp)
- lv3: R.Tensor((3, 3), dtype="float32") = R.add(lv2, lv2)
- lv3_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv3)
- lv3_ecp_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv3_ecp)
- lv4: R.Tensor((3, 3), dtype="float32") = R.add(lv3_ecp_scp,
lv3_ecp_scp)
- lv4_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv4)
- lv4_ecp_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv4_ecp)
- lv5: R.Tensor((3, 3), dtype="float32") = R.add(lv4_ecp_scp,
lv4_ecp_scp)
- lv6: R.Tensor((3, 3), dtype="float32") = R.add(lv5, lv5)
- lv6_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv6)
- lv6_ecp_scp: R.Tensor((3, 3), dtype="float32") =
R.grad.start_checkpoint(lv6_ecp)
- lv7: R.Tensor((3, 3), dtype="float32") = R.add(lv6_ecp_scp,
lv6_ecp_scp)
- lv8: R.Tensor((3, 3), dtype="float32") = R.add(lv7, lv7)
- lv8_ecp: R.Tensor((3, 3), dtype="float32") =
R.grad.end_checkpoint(lv8)
- gv: R.Tensor((3, 3), dtype="float32") = lv8_ecp
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv: R.Tensor((3, 3), "float32") = R.add(x_scp, x_scp)
+ lv1: R.Tensor((3, 3), "float32") = R.add(lv, lv)
+ lv1_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv1)
+ lv1_ecp_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv1_ecp)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1_ecp_scp,
lv1_ecp_scp)
+ lv3: R.Tensor((3, 3), "float32") = R.add(lv2, lv2)
+ lv3_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv3)
+ lv3_ecp_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv3_ecp)
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv3_ecp_scp,
lv3_ecp_scp)
+ lv4_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv4)
+ lv4_ecp_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv4_ecp)
+ lv5: R.Tensor((3, 3), "float32") = R.add(lv4_ecp_scp,
lv4_ecp_scp)
+ lv6: R.Tensor((3, 3), "float32") = R.add(lv5, lv5)
+ lv6_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv6)
+ lv6_ecp_scp: R.Tensor((3, 3), "float32") =
R.grad.start_checkpoint(lv6_ecp)
+ lv7: R.Tensor((3, 3), "float32") = R.add(lv6_ecp_scp,
lv6_ecp_scp)
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv7, lv7)
+ lv8_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv8)
+ gv: R.Tensor((3, 3), "float32") = lv8_ecp
R.output(gv)
return gv
# fmt: on
@@ -685,5 +686,70 @@ def test_checkpoint_sequential_checkpoint_last():
assert_structural_equal(bb.get(), Expected)
+def test_checkpoint_dag():
+ """Comp. graph is a DAG with only one output. Here we only test the simple
case: comp. graph
+ is a sequence of sub-graphs, and the checkpoints are the intersections of
connected
+ subgraphs."""
+
+ def func(x):
+ return x * relax.const(2, "float32") * relax.const(2, "float32")
+
+ bb = BlockBuilder()
+ x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ lv1 = bb.emit(nn.checkpoint(func, x))
+ lv2 = bb.emit(x * lv1)
+ lv3 = bb.emit(nn.checkpoint(func, lv2))
+ lv4 = bb.emit(lv2 * lv3)
+ lv5 = bb.emit(nn.checkpoint(func, lv4))
+ lv6 = bb.emit(lv4 * lv5)
+ gv = bb.emit_output(relax.op.sum(lv6))
+ bb.emit_func_output(gv)
+
+
+def test_checkpoint_with_intermediate_require_grads():
+ def func(x):
+ return x * x * x
+
+ bb = BlockBuilder()
+ x = relax.Var("x", relax.TensorStructInfo((3, 3), "float32"))
+ with bb.function("main", [x]):
+ with bb.dataflow():
+ lv1 = nn.emit_checkpoint(func, x)
+ gv = bb.emit_output(relax.op.sum(lv1))
+ bb.emit_func_output(gv)
+
+ # fmt: off
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv: R.Tensor((3, 3), "float32") = R.multiply(x, x)
+ lv1: R.Tensor((3, 3), "float32") = R.multiply(lv, x)
+ gv: R.Tensor((), "float32") = R.sum(lv1, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones(R.shape([]),
"float32")
+ lv1_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, R.shape([3, 3]))
+ lv1_adjoint_out: R.Tensor((3, 3), "float32") = lv1_adjoint
+ R.output(gv, lv1_adjoint_out)
+ return (gv, (lv1_adjoint_out,))
+
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ x_scp: R.Tensor((3, 3), "float32") = R.grad.start_checkpoint(x)
+ lv: R.Tensor((3, 3), "float32") = R.multiply(x_scp, x_scp)
+ lv1: R.Tensor((3, 3), "float32") = R.multiply(lv, x_scp)
+ lv1_ecp: R.Tensor((3, 3), "float32") =
R.grad.end_checkpoint(lv1)
+ gv: R.Tensor((), "float32") = R.sum(lv1_ecp, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+ # fmt: on
+
+ After = relax.transform.Gradient("main", lv1)(bb.get())
+ assert_structural_equal(After, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()