This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0c447d6f9c [Unity][Transform] High-level reverse-mode automatic
differentiation pass (#14542)
0c447d6f9c is described below
commit 0c447d6f9c85cb1f2b3f246a1c58cb6a349bf600
Author: Chaofan Lin <[email protected]>
AuthorDate: Sun Apr 9 20:13:21 2023 +0800
[Unity][Transform] High-level reverse-mode automatic differentiation pass
(#14542)
### Introduction
This PR introduces the high-level reverse-mode automatic differentiation
pass `Gradient` for Relax. It is the core component for training or
fine-tuning models in Relax IR.
Before upstreaming, this work was actively iterated on and maintained in
several forks, such as [mlc](https://github.com/mlc-ai/relax) and
[relax-training](https://github.com/ACMClass-TVM-20/relax-training). It has
now reached a relatively stable state, so it is time to upstream this
important work to the unity branch.
The Python side API:
- `Gradient(func_name: str, require_grads: Optional[Union[Var, List[Var]]]
= None, target_index: int = 0) -> tvm.ir.transform.Pass`
It transforms the given function in the IRModule and adds a new function
that calculates the gradient with respect to the function's output.
### Examples
```
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            # use R.sum to reduce the tensor to a scalar
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

After = relax.transform.Gradient("main")(Module)
```
Then the transformed module `After` will be:
```
@I.ir_module
class After:
    @R.function
    def main(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tensor((), dtype="float32"):
        with R.dataflow():
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            R.output(lv2)
        return lv2

    @R.function
    def main_adjoint(
        x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
    ) -> R.Tuple(
        R.Tensor((), dtype="float32"),
        R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
    ):
        with R.dataflow():
            # original bindings
            lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
            lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
            # bindings w.r.t. intermediate variables
            lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
            lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
                lv2_adjoint, (3, 3)
            )
            # bindings w.r.t. parameters
            x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
            R.output(lv2, x_adjoint, y_adjoint)
        # return value: (orig_return_values, tuple(adjoints))
        return (lv2, (x_adjoint, y_adjoint))
```
Here we specify the target function `main` by its name.
We leave `require_grads` at its default value (`None`), so the pass
calculates the adjoints of all inputs (`x_adjoint`, `y_adjoint`) and returns
them.
We leave `target_index` at its default value (`0`), so the unique return
value `lv2` is taken as the target of AD (the scalar from which we start
differentiating and propagating adjoints).
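If only some parameters need gradients, `require_grads` can be passed explicitly. A minimal sketch (the lookup of parameter Vars via `Module["main"].params` follows the pattern used in the tests of this PR; the names here are illustrative):
```
# Compute the adjoint of x only. require_grads accepts a single relax.Var
# or a list of relax.Vars; they must be parameters of the target function.
x = Module["main"].params[0]
AfterXOnly = relax.transform.Gradient("main", require_grads=x)(Module)
# main_adjoint in AfterXOnly then returns (lv2, (x_adjoint,)).
```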
Co-authored-by: Yixin Dong <[email protected]>
---
include/tvm/relax/transform.h | 26 +
python/tvm/relax/transform/transform.py | 170 +++
src/relax/transform/gradient.cc | 469 ++++++++
src/relax/transform/utils.cc | 13 +
src/relax/transform/utils.h | 16 +
tests/python/relax/test_transform_gradient.py | 1164 ++++++++++++++++++++
.../relax/test_transform_gradient_numeric.py | 192 ++++
7 files changed, 2050 insertions(+)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5615a0951f..fec2ef0a04 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -348,6 +348,32 @@ class PatternCheckContext : public ObjectRef {
PatternCheckContextNode);
};
+/*!
+ * \brief Reverse-mode automatic differentiation.
+ *
+ * This pass will differentiate one function in the IRModule. Currently the input function must
+ * have only one dataflow block.
+ *
+ * For a given function specified by `func_name`, it generates a new function with the name
+ * `func_name + "_adjoint"`. The new function computes the gradient of the **differentiation
+ * target** with respect to the arguments of the original function specified by `require_grads`.
+ *
+ * If the function has only one return value, that return value is the target. If the function
+ * has more than one return value, the target is the target_index-th return value. The target
+ * must be a scalar (0-dim tensor).
+ *
+ * \param func_name The name of the specified function.
+ * \param require_grads The relax variables whose adjoints are needed. They must be parameters
+ * of the given function and must not contain duplicates. If not specified, the adjoints of all
+ * parameters will be computed.
+ * \param target_index If the specified function has more than one return value, the index of
+ * the return value to use as the target. If not specified, the first return value will be the
+ * target.
+ * \return The Pass.
+ */
+TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads = NullOpt,
+                      int target_index = 0);
+
/*!
 * \brief Apply pattern matching to each function in the given module, and group matched
 * expressions into a new function. The end result is similar to FuseOps, but fusion is driven
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 83f2f1bba8..a53c45b655 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -46,6 +46,176 @@ class DataflowBlockPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relax.DataflowBlock in a module."""
+def Gradient(
+    func_name: str, require_grads: Optional[Union[Var, List[Var]]] = None, target_index: int = 0
+) -> tvm.ir.transform.Pass:
+    """Reverse-mode automatic differentiation.
+
+    This pass will differentiate one function in the IRModule. Currently the input function
+    must have only one dataflow block.
+
+    For a given function specified by `func_name`, it generates a new function with the name
+    `func_name + "_adjoint"`. The new function computes the gradient of the **differentiation
+    target** with respect to the arguments of the original function specified by
+    `require_grads`.
+
+    If the function has only one return value, that return value is the target. If the
+    function has more than one return value, the target is the target_index-th return value.
+    The target must be a scalar (0-dim tensor).
+
+    The new function will be like:
+
+    .. code-block:: python
+
+        @R.function
+        def main_adjoint(original_parameters):
+            with R.dataflow():
+                # the bindings of the original function
+                ...
+                # calculating the gradients
+                ...
+                R.output(original_outputs, grad_1, grad_2, ...)
+            return (original_return_value, (grad_1, grad_2, ...))
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the specific function.
+
+    require_grads : Optional[Union[relax.Var, List[relax.Var]]]
+        The relax variables whose adjoints are needed. They must be parameters of the given
+        function and must not contain duplicates. If not specified, the adjoints of all
+        parameters will be computed.
+
+    target_index : int
+        If the specified function has more than one return value, the index of the return
+        value to use as the target. If not specified, the first return value will be the
+        target.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+        The Pass.
+
+    Examples
+    --------
+    The following code shows how to use this pass:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tensor((), dtype="float32"):
+                with R.dataflow():
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    # use R.sum to reduce the tensor to a scalar
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                    R.output(lv2)
+                return lv2
+
+        After = relax.transform.Gradient("main")(Module)
+
+    The module after the Gradient pass will be:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class After:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tensor((), dtype="float32"):
+                with R.dataflow():
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                    R.output(lv2)
+                return lv2
+
+            @R.function
+            def main_adjoint(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tuple(
+                R.Tensor((), dtype="float32"),
+                R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
+            ):
+                with R.dataflow():
+                    # original bindings
+                    lv1: R.Tensor((3, 3), dtype="float32") = R.add(x, y)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                    # bindings w.r.t. intermediate variables
+                    lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
+                    lv1_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
+                        lv2_adjoint, (3, 3)
+                    )
+                    # bindings w.r.t. parameters
+                    x_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
+                    y_adjoint: R.Tensor((3, 3), dtype="float32") = lv1_adjoint
+                    R.output(lv2, x_adjoint, y_adjoint)
+                # return value: (orig_return_values, tuple(adjoints))
+                return (lv2, (x_adjoint, y_adjoint))
+
+    The second example returns multiple values and specifies the target with `target_index`:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
+                with R.dataflow():
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
+                    R.output(lv1, lv2)
+                return (lv1, lv2)
+
+        After = relax.transform.Gradient("main", target_index=1)(Module)
+
+    The module after the Gradient pass will be:
+
+    .. code-block:: python
+
+        @I.ir_module
+        class After:
+            @R.function
+            def main(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")):
+                with R.dataflow():
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
+                    R.output(lv1, lv2)
+                return (lv1, lv2)
+
+            @R.function
+            def main_adjoint(
+                x: R.Tensor((3, 3), dtype="float32"), y: R.Tensor((3, 3), dtype="float32")
+            ) -> R.Tuple(
+                R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")),
+                R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tensor((3, 3), dtype="float32")),
+            ):
+                with R.dataflow():
+                    # original bindings
+                    lv1: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
+                    lv2: R.Tensor((), dtype="float32") = R.sum(y, axis=None, keepdims=False)
+                    # bindings w.r.t. intermediate variables
+                    # gradients of intermediate variables that are not related to the target
+                    # will not be calculated
+                    lv2_adjoint: R.Tensor((), dtype="float32") = R.ones((), dtype="float32")
+                    # bindings w.r.t. parameters
+                    x_adjoint: R.Tensor((3, 3), dtype="float32") = R.zeros((3, 3), dtype="float32")
+                    y_adjoint: R.Tensor((3, 3), dtype="float32") = R.broadcast_to(
+                        lv2_adjoint, (3, 3)
+                    )
+                    R.output(lv1, lv2, x_adjoint, y_adjoint)
+                # return value: (orig_return_values, tuple(adjoints))
+                return ((lv1, lv2), (x_adjoint, y_adjoint))
+    """
+    if require_grads is not None and not isinstance(require_grads, list):
+        require_grads = [require_grads]
+
+    return _ffi_api.Gradient(func_name, require_grads, target_index)  # type: ignore
+
+
def ToNonDataflow() -> tvm.ir.transform.Pass:
"""Transform all dataflow structure to non-dataflow version.
diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc
new file mode 100644
index 0000000000..e7bdea6036
--- /dev/null
+++ b/src/relax/transform/gradient.cc
@@ -0,0 +1,469 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/transform/gradient.cc
+ * \brief Reverse-mode automatic differentiation.
+ *
+ * Currently this only supports differentiating one function in the IRModule, which must have
+ * exactly one dataflow block, with respect to the specified return value of the function, which
+ * must be a scalar.
+ */
+
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/transform.h>
+
+#include <unordered_set>
+
+#include "../op/tensor/binary.h"
+#include "../op/tensor/create.h"
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+using AdjointMsg = NestedMsg<Expr>;
+
+// A helper class for GradientMutator
+// Visits the forward bindings and generates the backward bindings
+class BackwardBindingGenerator : private ExprVisitor {
+ public:
+  /*!
+   * \brief Generate the backward bindings for the corresponding GradientMutator
+   *
+   * \param builder The BlockBuilder of GradientMutator, used to generate bindings
+   * \param forward_block The forward DataflowBlock
+   * \param require_grads The Var list to differentiate w.r.t.
+   * \param target_var The target Var to differentiate
+   * \param orig_return_value The original return value of the function. The new return value is
+   * a 2-tuple, containing the original return value and a tuple of the adjoints of parameters.
+   * \return The return expr of the new adjoint function.
+   */
+  static Expr Generate(const BlockBuilder& builder, const DataflowBlock& forward_block,
+                       const Array<Var>& require_grads, const Var& target_var,
+                       const Expr& orig_return_value) {
+ BackwardBindingGenerator generator(builder);
+
+    // Initialize the adjoint of target_var with the ones op. We have already checked the target.
+    auto* target_sinfo = GetStructInfoAs<TensorStructInfoNode>(target_var);
+    const Expr& target_adjoint = ones(target_sinfo->shape.value(), target_sinfo->dtype);
+    UpdateStructInfo(target_adjoint, GetRef<StructInfo>(target_sinfo));
+    generator.adjoint_msg_map_.Set(target_var, AdjointMsg(target_adjoint));
+
+    // We do reverse-mode AD, so we visit the bindings in reverse order
+    for (auto it = forward_block->bindings.rbegin(); it != forward_block->bindings.rend(); ++it) {
+      generator.VisitBinding(*it);
+    }
+
+ return generator.Epilogue(require_grads, orig_return_value);
+ }
+
+ private:
+  explicit BackwardBindingGenerator(const BlockBuilder& builder) : builder_(builder) {}
+
+ void VisitBinding(const Binding& binding) final {
+ // TODO(chaofan, yixin): support other types of bindings
+    CHECK(binding->IsInstance<VarBindingNode>()) << "Currently only VarBinding is supported";
+ auto* var_binding = binding.as<VarBindingNode>();
+
+ auto it = adjoint_msg_map_.find(var_binding->var);
+ if (it == adjoint_msg_map_.end()) {
+ // This var is not used in the following bindings
+ return;
+ }
+
+    // We have reached the definition of binding->var
+    // Create the adjoint var and bind the adjoint value to it
+ EmitAdjoint(var_binding->var, (*it).second, true);
+
+ Expr value = var_binding->value;
+ // TODO(chaofan, yixin): support other types of binding values
+    CHECK(value->IsInstance<CallNode>() || value->IsInstance<TupleNode>() ||
+          value->IsInstance<TupleGetItemNode>() || value->IsInstance<VarNode>() ||
+          value->IsInstance<ConstantNode>())
+        << "This type of binding value is not supported yet: " << value;
+
+ ExprVisitor::VisitBinding_(var_binding);
+ }
+
+  // Handle the adjoint exprs of the inputs of the binding
+  // For call nodes, we call the registered gradient functions
+  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
+ static const OpAttrMap<FPrimalGradient>& gradient_op_map =
+ Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
+
+ Var adjoint_var = adjoint_var_map_[binding->var];
+ const Op& call_op = Downcast<Op>(call->op);
+    const Array<Expr>& partials =
+        gradient_op_map[call_op](binding->var, GetRef<Call>(call), adjoint_var, builder_);
+    ICHECK(partials.size() == call->args.size()) << "partials number != inputs number";
+
+ for (size_t i = 0; i < partials.size(); ++i) {
+ Expr partial = partials[i];
+ if (IsCallNoGrad(partial)) { // no grad: don't update
+ continue;
+ }
+ if (!partial->struct_info_.defined()) {
+ UpdateStructInfo(partial, GetStructInfo(call->args[i]));
+ }
+ UpdateAdjoint(call->args[i], partial);
+ }
+ }
+
+  // For Tuple nodes, we iterate over the input tuple and update the adjoint expr of each input
+  // e.g.
+  //   a = (b, c)
+  //   b_adjoint += a_adjoint_var[0], c_adjoint += a_adjoint_var[1]
+  //   a = ((b, c), d)
+  //   b_adjoint += a_adjoint_var[0][0], c_adjoint += a_adjoint_var[0][1],
+  //   d_adjoint += a_adjoint_var[1]
+  //
+  // Here we use adjoint_var to simplify the calculation
+  void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final {
+ UpdateAdjoint(GetRef<Tuple>(tuple), adjoint_var_map_[binding->var]);
+ }
+
+  // For TupleGetItem nodes, we do a partial update
+  // e.g.
+  //   b = a[0]
+  //   a_adjoint[0] += b_adjoint_var
+  // If a_adjoint does not exist, we first create a zeros tuple as a_adjoint, and then add
+  void VisitBinding_(const VarBindingNode* binding, const TupleGetItemNode* tuple_get_item) final {
+    ICHECK(tuple_get_item->tuple->IsInstance<VarNode>())
+        << "The tuple field of a TupleGetItem is not bound to a Var";
+    auto* tuple_sinfo = GetStructInfoAs<TupleStructInfoNode>(tuple_get_item->tuple);
+    ICHECK(tuple_sinfo) << "The tuple field of a TupleGetItem must have a TupleStructInfo";
+
+ const Var& tuple_var = Downcast<Var>(tuple_get_item->tuple);
+    if (adjoint_msg_map_.count(tuple_var) == 0) {
+      const AdjointMsg& init = InitZerosAdjointNested(GetRef<StructInfo>(tuple_sinfo));
+      adjoint_msg_map_.Set(tuple_var, init);
+    }
+
+    adjoint_msg_map_.Set(tuple_var,
+                         AddInAdjointMsg(adjoint_msg_map_[tuple_var], tuple_get_item->index,
+                                         ExprToAdjointMsg(adjoint_var_map_[binding->var])));
+  }
+
+  // For assign bindings, we add the adjoint of the output to the adjoint of the input
+  void VisitBinding_(const VarBindingNode* binding, const DataflowVarNode* var) final {
+ UpdateAdjoint(GetRef<Var>(var), adjoint_var_map_[binding->var]);
+ }
+
+ void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final {
+ UpdateAdjoint(GetRef<Var>(var), adjoint_var_map_[binding->var]);
+ }
+
+  // For constant nodes, nothing needs to be done because they do not contribute to the adjoint
+  void VisitBinding_(const VarBindingNode* binding, const ConstantNode* var) final { return; }
+
+ // Add partial (Expr type) to the adjoint of expr
+ void UpdateAdjoint(const Expr& expr, const Expr& partial) {
+    DecomposeNestedMsg(expr, ExprToAdjointMsg(partial), [&](Expr leaf, AdjointMsg msg) {
+ if (leaf->IsInstance<VarNode>()) {
+ const Var& v = Downcast<Var>(leaf);
+ if (adjoint_msg_map_.count(v) == 0) {
+ adjoint_msg_map_.Set(v, msg);
+ } else {
+ adjoint_msg_map_.Set(v, TupleAwareAdd(adjoint_msg_map_[v], msg));
+ }
+ } else if (leaf->IsInstance<ConstantNode>()) {
+ // nothing to do
+ } else if (leaf->IsInstance<ShapeExprNode>()) {
+ // must be no grad
+ ICHECK(IsCallNoGrad(partial));
+ } else {
+ LOG(FATAL) << "UpdateAdjoint: leaf type not supported. Currently Var
and Constant leaves "
+ "are supported.";
+ }
+ });
+ }
+
+  // Transform the adjoint expressed as a NestedMsg<Expr> into an adjoint Expr, then emit it
+  // If the adjoint is assigned to a DataflowVar (i.e. it corresponds to a non-output binding),
+  // it is stored in adjoint_var_map_ for future lookup
+  Var EmitAdjoint(const Var& source_var, const AdjointMsg& adjoint, bool is_dataflow_var) {
+    Var adjoint_var;
+    if (is_dataflow_var) {
+      adjoint_var = builder_->Emit(AdjointMsgToExpr(adjoint), source_var->name_hint() + "_adjoint");
+      adjoint_var_map_.Set(source_var, adjoint_var);
+    } else {
+      adjoint_var =
+          builder_->EmitOutput(AdjointMsgToExpr(adjoint), source_var->name_hint() + "_adjoint");
+ }
+ return adjoint_var;
+ }
+
+  // Handle the return value of the AD function.
+  // Returns the new return value, which is of the form:
+  //   Tuple(original_return_value,
+  //         Tuple(adjoint_of_require_grads_1, adjoint_of_require_grads_2, ...))
+  Expr Epilogue(const Array<Var>& require_grads, const Expr& orig_return_value) {
+    // create adjoint variables for the inputs, and then bind the adjoints
+    Array<Expr> out_adjoints;
+
+    for (Var var : require_grads) {
+      // If the var does not have an adjoint msg, it does not contribute to the target,
+      // so its adjoint is zeros
+      AdjointMsg adjoint =
+          adjoint_msg_map_.Get(var).value_or(InitZerosAdjointNested(GetStructInfo(var)));
+      Var adjoint_var = EmitAdjoint(var, adjoint, false);
+ out_adjoints.push_back(adjoint_var);
+ }
+
+ return Tuple({orig_return_value, Tuple(out_adjoints)});
+ }
+
+ static bool IsCallZeros(const Expr& expr) {
+    return expr->IsInstance<CallNode>() && Downcast<Call>(expr)->op == Op::Get("relax.zeros");
+ }
+
+ static bool IsCallNoGrad(const Expr& expr) {
+ return expr->IsInstance<CallNode>() &&
+ Downcast<Call>(expr)->op == Op::Get("relax.grad.no_grad");
+ }
+
+ static Expr AdjointMsgToExpr(AdjointMsg msg) {
+ return NestedMsgToExpr<Expr>(msg, [](Optional<Expr> leaf_expr) {
+ if (!leaf_expr.defined()) {
+ LOG(FATAL) << "Null should not exist in AdjointMsg.";
+ }
+ return leaf_expr.value();
+ });
+ }
+
+ static AdjointMsg ExprToAdjointMsg(Expr expr) {
+ return MapToNestedMsgBySInfo<Expr>(expr, [](Expr leaf) {
+ ICHECK(GetStructInfoAs<TensorStructInfoNode>(leaf))
+ << "The leaf of adjoint: " << leaf << " should have StructInfo and
be a Tensor.";
+ return AdjointMsg(leaf);
+ });
+ }
+
+  // Create a zeros AdjointMsg with the specified struct info
+  // When sinfo is a TupleStructInfo, we create a nested zeros Tuple
+  static AdjointMsg InitZerosAdjointNested(const StructInfo& sinfo) {
+    return MapToNestedMsg<Expr>(sinfo, [](StructInfo sinfo) {
+      auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+      ICHECK(tensor_sinfo) << "The leaf of adjoint should be a Tensor.";
+      ICHECK(tensor_sinfo->shape.defined()) << "Error: missing shape when building zeros tuple.";
+      const Expr& init = zeros(tensor_sinfo->shape.value(), tensor_sinfo->dtype);
+ UpdateStructInfo(init, sinfo);
+ return init;
+ });
+ }
+
+ // Return base + increment. A tuple-aware addition.
+  static AdjointMsg TupleAwareAdd(const AdjointMsg& base, const AdjointMsg& increment) {
+ return CombineNestedMsg(base, increment, [&](Expr lhs, Expr rhs) {
+ // a small optimization: a+0=a, 0+a=a.
+ if (IsCallZeros(lhs)) {
+ return rhs;
+ } else if (IsCallZeros(rhs)) {
+ return lhs;
+ }
+ auto* sinfo = GetStructInfoAs<TensorStructInfoNode>(lhs);
+      ICHECK(sinfo) << "The leaf of adjoint should have StructInfo and be a Tensor.";
+ ICHECK(GetStructInfoAs<TensorStructInfoNode>(rhs))
+ << "The leaf of adjoint should have StructInfo and be a Tensor.";
+ Expr res = add(lhs, rhs);
+ UpdateStructInfo(res, GetRef<StructInfo>(sinfo));
+ return res;
+ });
+ }
+
+  // Perform an addition at a specified position of a tuple.
+  // e.g. given tuple=(a, b, c), index=1, and increment=d, return (a, b+d, c)
+ static AdjointMsg AddInAdjointMsg(const AdjointMsg& adjoint, int index,
+ const AdjointMsg& increment) {
+ ICHECK(adjoint.IsNested()) << "The adjoint should be nested.";
+ Array<AdjointMsg> arr = adjoint.NestedArray();
+ ICHECK(index >= 0 && index < static_cast<int>(arr.size()));
+ arr.Set(index, TupleAwareAdd(arr[index], increment));
+ return AdjointMsg(arr);
+ }
+
+ // The block builder of the corresponding GradientMutator, to emit bindings
+ BlockBuilder builder_;
+ // Forward Var to its adjoint Var
+ Map<Var, Var> adjoint_var_map_;
+  // Forward Var to its adjoint NestedMsg<Expr>
+  // We use NestedMsg<Expr> to store the adjoint information (equivalent to the adjoint Expr)
+  // When emitting, the adjoint information is transformed into an adjoint Expr
+ Map<Var, AdjointMsg> adjoint_msg_map_;
+};
+
+class GradientMutator : private ExprMutator {
+ public:
+  static IRModule Transform(IRModule mod, String func_name, Optional<Array<Var>> require_grads,
+                            int target_index) {
+    auto* old_func_ptr = mod->Lookup(func_name).as<FunctionNode>();
+    CHECK(old_func_ptr) << func_name << " is not a Relax Function";
+    auto old_func = GetRef<Function>(old_func_ptr);
+
+    // when require_grads is not specified, it is set to all params of the function
+    auto require_grads_value = require_grads.value_or(old_func->params);
+
+ CheckRequireGrads(require_grads_value, old_func->params, func_name);
+
+ Function new_func = CopyWithNewVars(old_func);
+ // map the parameter list into new params
+ for (size_t i = 0; i < require_grads_value.size(); ++i) {
+      int idx =
+          std::find(old_func->params.begin(), old_func->params.end(), require_grads_value[i]) -
+          old_func->params.begin();
+ require_grads_value.Set(i, new_func->params[idx]);
+ }
+
+ GradientMutator mutator(mod, require_grads_value, target_index);
+    Function new_func_transformed = Downcast<Function>(mutator.VisitExpr(new_func));
+
+ IRModule new_module = GetRef<IRModule>(mod.CopyOnWrite());
+ new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed);
+ return new_module;
+ }
+
+ private:
+  GradientMutator(const IRModule& module, const Array<Var>& require_grads, int target_index)
+      : ExprMutator(module), require_grads_(require_grads), target_index_(target_index) {}
+
+ Expr VisitExpr_(const FunctionNode* func) final {
+    CHECK(func->body->IsInstance<SeqExprNode>()) << "The body of the function must be a SeqExpr.";
+
+ Expr new_body = this->VisitExpr(func->body);
+
+ return Function(func->params, new_body, NullOpt, func->attrs);
+ }
+
+ Expr VisitExpr_(const SeqExprNode* seq_expr) final {
+ // TODO(chaofan, yixin): multiple blocks AD
+    CHECK(seq_expr->blocks.size() == 1) << "Currently only one dataflow block is supported";
+    // TODO(chaofan, yixin): AD in non-dataflow blocks.
+    CHECK(seq_expr->blocks[0]->IsInstance<DataflowBlockNode>())
+        << "Currently only one dataflow block is supported";
+
+ // the return value should be a VarNode, and a scalar
+ orig_return_expr_ = seq_expr->body;
+ CheckAndSetTarget(seq_expr->body, target_index_);
+
+ BindingBlock new_block = this->VisitBindingBlock(seq_expr->blocks[0]);
+ return SeqExpr({new_block}, this->return_expr_);
+ }
+
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+ builder_->BeginDataflowBlock();
+ // accept bindings in the original block
+ for (const auto& binding : block->bindings) {
+ this->VisitBinding(binding);
+ }
+
+ // generate backward bindings and the return value
+    return_expr_ = BackwardBindingGenerator::Generate(this->builder_, GetRef<DataflowBlock>(block),
+                                                      this->require_grads_, this->target_var_,
+                                                      orig_return_expr_);
+
+ return builder_->EndBlock();
+ }
+
+ static bool IsFloatTensorSInfo(const StructInfo& sinfo) {
+ auto* tensor_sinfo = sinfo.as<TensorStructInfoNode>();
+ return tensor_sinfo && tensor_sinfo->dtype.is_float();
+ }
+
+ // When the return value is a Var, it is the target;
+  // when the return value is a Tuple, the target is the target_index-th field of the return value
+  // Check that the target is a Var with scalar tensor struct_info
+ void CheckAndSetTarget(const Expr& e, int target_index) {
+ if (auto* var = e.as<VarNode>()) {
+ CHECK_EQ(target_index, 0) << "When the function has only one return
value, target_index can "
+ "only be 0. But the target_index specified
is "
+ << target_index;
+ target_var_ = GetRef<Var>(var);
+ } else if (auto* tuple = e.as<TupleNode>()) {
+      CHECK(target_index >= 0 && target_index < static_cast<int>(tuple->fields.size()))
+          << "target_index should be within the range of the number of return values of the "
+             "function. But the specified target_index is "
+          << target_index << ", while the number of return values is " << tuple->fields.size();
+ auto* var = tuple->fields[target_index].as<VarNode>();
+ CHECK(var) << "Target must be a Var, but the specified target is "
+ << tuple->fields[target_index];
+ target_var_ = GetRef<Var>(var);
+ } else {
+ LOG(FATAL) << "The return value of the function must be Var or Tuple.
However, the return "
+ "value of the given function is "
+ << e;
+ }
+ auto target_sinfo = GetStructInfo(target_var_);
+ CHECK(IsScalarTensor(target_sinfo) && IsFloatTensorSInfo(target_sinfo))
+ << "The differentiation target must be a float scalar (0-dim Tensor),
but the StructInfo "
+ "of the given target "
+ << target_var_ << " is " << GetStructInfo(target_var_);
+ }
+
+ // Check every Var in require_grads:
+ // 1. there should be no duplicate var
+ // 2. every var should be a parameter of the function
+  // 3. the type of the input var should be a Tensor of floating point dtype, or a Tuple of that
+  static void CheckRequireGrads(const Array<Var>& require_grads, const Array<Var>& func_params,
+                                const String& func_name) {
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set;
+ for (const auto& var : require_grads) {
+      CHECK(std::find(func_params.begin(), func_params.end(), var) != func_params.end())
+          << "There is no Var named " << var->name_hint() << " in the parameters of the function "
+          << func_name;
+      CHECK_EQ(var_set.count(var), 0) << "Var " << var->name_hint() << " appears more than once";
+ var_set.emplace(var);
+
+ CHECK(IsNestedTensorConditioned(GetStructInfo(var), IsFloatTensorSInfo))
+ << "Only Tensors of floating point dtype or Tuples of float "
+ "Tensors can require gradients, but the StructInfo of Var "
+ << var->name_hint() << " is " << GetStructInfo(var);
+ }
+ }
+
+ // differentiation sources
+ Array<Var> require_grads_;
+ // the differentiation target
+ int target_index_;
+ Var target_var_;
+ // the return value of the original function and the differentiated function
+ Expr orig_return_expr_;
+ Expr return_expr_;
+};
+
+namespace transform {
+
+Pass Gradient(String func_name, Optional<Array<Var>> require_grads, int target_index) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) {
+        return relax::GradientMutator::Transform(mod, func_name, require_grads, target_index);
+      };
+ return CreateModulePass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"Gradient",
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.Gradient").set_body_typed(Gradient);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/utils.cc b/src/relax/transform/utils.cc
index 9a19115f62..c0fde3bd4c 100644
--- a/src/relax/transform/utils.cc
+++ b/src/relax/transform/utils.cc
@@ -22,6 +22,19 @@
namespace tvm {
namespace relax {
+bool IsScalarTensor(const StructInfo& sinfo) {
+ if (!sinfo->IsInstance<TensorStructInfoNode>()) {
+ return false;
+ }
+ TensorStructInfo tensor_sinfo = Downcast<TensorStructInfo>(sinfo);
+  if (!tensor_sinfo->shape.defined() || !tensor_sinfo->shape->IsInstance<ShapeExprNode>()) {
+ return false;
+ }
+ return tensor_sinfo->shape.as<ShapeExprNode>()->values.size() == 0;
+}
+
+bool IsScalarTensor(const Expr& expr) { return IsScalarTensor(GetStructInfo(expr)); }
+
bool IsNestedTensor(const StructInfo& sinfo) {
  return IsNestedTensorConditioned(sinfo, [](const TensorStructInfo& sinfo) { return true; });
}
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index d51fe53101..9334fd8347 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -143,6 +143,22 @@ IRModule MakeGroupedFunctions(
const std::unordered_map<const Object*, relay::GraphPartitioner::Group*>&
partition,
bool lift_constants = true);
+/*!
+ * \brief Check if the given StructInfo is a scalar tensor. The sinfo should be an instance of
+ * TensorStructInfo, and its shape must be a ShapeExpr.
+ * \param sinfo The StructInfo to be checked.
+ * \return true if the given StructInfo is a scalar tensor.
+ */
+bool IsScalarTensor(const StructInfo& sinfo);
+
+/*!
+ * \brief Check if the given expr is a scalar tensor. Currently the shape of the tensor expr
+ * must be a ShapeExpr.
+ * \param expr The expr to be checked.
+ * \return true if the given expr is a scalar tensor.
+ */
+bool IsScalarTensor(const Expr& expr);
+
/*!
 * \brief Check if the given StructInfo is a nested tensor StructInfo satisfying the given
 * condition f_condition.
diff --git a/tests/python/relax/test_transform_gradient.py
b/tests/python/relax/test_transform_gradient.py
new file mode 100644
index 0000000000..1b3d174c13
--- /dev/null
+++ b/tests/python/relax/test_transform_gradient.py
@@ -0,0 +1,1164 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import pytest
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.script.parser import relax as R, tir as T, ir as I
+from tvm._ffi.base import TVMError
+import numpy as np
+
+
+def test_simple():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.sum(x)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32",
ndim=0):
+ with R.dataflow():
+ gv: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+        def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor(None, "float32", ndim=0), R.Tuple(R.Tensor(None, "float32", ndim=2))):
+ with R.dataflow():
+ gv: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ x_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_assign_binding():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ lv1 = x
+ lv2 = lv1
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = x
+ lv2: R.Tensor((3, 3), "float32") = lv1
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = x
+ lv2: R.Tensor((3, 3), "float32") = lv1
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+ x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_multiple_uses():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ lv1 = R.add(x, x)
+ lv2 = R.add(lv1, x)
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.add(lv2_adjoint,
lv1_adjoint)
+ x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv, lv1_adjoint)
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_unused():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ lv1 = R.add(x, x)
+ lv2 = R.add(lv1, x)
+ gv = R.sum(x)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+ gv: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, x)
+ gv: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ x_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_default_require_grads():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((3, 3), "float32"),
+ y: R.Tensor((3, 3), "float32"),
+ z: R.Tensor((3, 3), "float32"),
+ ):
+ with R.dataflow():
+ lv1 = R.add(x, y)
+ lv2 = R.add(lv1, z)
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected1:
+ @R.function
+ def main(
+ x: R.Tensor((3, 3), "float32"),
+ y: R.Tensor((3, 3), "float32"),
+ z: R.Tensor((3, 3), "float32"),
+ ) -> R.Tensor((), "float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"),
R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+ x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+ y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+ z_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+ R.output(gv, x_adjoint, y_adjoint, z_adjoint)
+ return (gv, (x_adjoint, y_adjoint, z_adjoint))
+ # fmt: on
+
+ After1 = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After1, Expected1)
+
+ # fmt: off
+ @I.ir_module
+ class Expected2:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"),
R.Tuple(R.Tensor((3, 3), "float32"))):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = R.add(lv1, z)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv2_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv2_adjoint
+ x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After2 = relax.transform.Gradient("main",
require_grads=Before["main"].params[0])(Before)
+ assert_structural_equal(After2, Expected2)
+
+
+def test_target_index():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")):
+ with R.dataflow():
+ lv1 = x
+ lv2 = R.sum(x)
+ lv3 = R.sum(y)
+ R.output(lv1, lv2, lv3)
+ return (lv1, lv2, lv3)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"),
R.Tensor((), "float32")):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = x
+ lv2: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ lv3: R.Tensor((), "float32") = R.sum(y, axis=None,
keepdims=False)
+ R.output(lv1, lv2, lv3)
+ return (lv1, lv2, lv3)
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((),
"float32"), R.Tensor((), "float32")), R.Tuple(R.Tensor((3, 3), "float32"),
R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = x
+ lv2: R.Tensor((), "float32") = R.sum(x, axis=None,
keepdims=False)
+ lv3: R.Tensor((), "float32") = R.sum(y, axis=None,
keepdims=False)
+ lv3_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ x_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3),
"float32")
+ y_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(lv3_adjoint, (3, 3))
+ R.output(lv1, lv2, lv3, x_adjoint, y_adjoint)
+ return ((lv1, lv2, lv3), (x_adjoint, y_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main", target_index=2)(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_tuple():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")),
+ y: R.Tensor((3, 3), "float32"),
+ z: R.Tensor((3, 3), "float32"),
+ ):
+ with R.dataflow():
+ lv1 = (y, z)
+ lv2 = x[0]
+ lv3 = lv1[0]
+ lv4 = R.add(lv2, lv3)
+ gv = R.sum(lv4)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) ->
R.Tensor(None, "float32", ndim=0):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (y, z)
+ lv2: R.Tensor((3, 3), "float32") = x[0]
+ lv3: R.Tensor((3, 3), "float32") = lv1[0]
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3),
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tensor((3,
3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"),
R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (y, z)
+ lv2: R.Tensor((3, 3), "float32") = x[0]
+ lv3: R.Tensor((3, 3), "float32") = lv1[0]
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv2, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv4, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv4_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+ lv2_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv3_adjoint, lv)
+ lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ x_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv2_adjoint, lv11)
+ y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[0]
+ z_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+ R.output(gv, x_adjoint, y_adjoint, z_adjoint)
+ return (gv, (x_adjoint, y_adjoint, z_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_tuple_assignment():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")):
+ with R.dataflow():
+ lv1 = (x, y)
+ lv4 = lv1[0]
+ lv7 = R.add(lv4, x)
+ lv2 = lv1
+ lv3 = lv2[0]
+ lv5 = R.add(lv3, lv7)
+ gv = R.sum(lv5)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv4: R.Tensor((3, 3), "float32") = lv1[0]
+ lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x)
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv1
+ lv3: R.Tensor((3, 3), "float32") = lv2[0]
+ lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv5, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3),
"float32"), R.Tensor((3, 3), "float32"))):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv4: R.Tensor((3, 3), "float32") = lv1[0]
+ lv7: R.Tensor((3, 3), "float32") = R.add(lv4, x)
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv1
+ lv3: R.Tensor((3, 3), "float32") = lv2[0]
+ lv5: R.Tensor((3, 3), "float32") = R.add(lv3, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv5, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv5_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv3_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv3_adjoint, lv)
+ lv7_adjoint: R.Tensor((3, 3), "float32") = lv5_adjoint
+ lv4_adjoint: R.Tensor((3, 3), "float32") = lv7_adjoint
+ lv11: R.Tensor((3, 3), "float32") = lv2_adjoint[0]
+ lv21: R.Tensor((3, 3), "float32") = R.add(lv11, lv4_adjoint)
+ lv31: R.Tensor((3, 3), "float32") = lv2_adjoint[1]
+ lv1_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv21, lv31)
+ lv41: R.Tensor((3, 3), "float32") = lv1_adjoint[0]
+ x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint,
lv41)
+ y_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+ R.output(gv, x_adjoint, y_adjoint)
+ return (gv, (x_adjoint, y_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_tuple_nested():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")), R.Tensor((3, 3), "float32")),
+ y: R.Tensor((3, 3), "float32"),
+ z: R.Tensor((3, 3), "float32"),
+ u: R.Tensor((3, 3), "float32"),
+ ):
+ with R.dataflow():
+ lv1 = ((y, z), u)
+ lv2 = x[0]
+ lv3 = lv2[0]
+ lv4 = lv1[0]
+ lv5 = lv4[1]
+ lv6 = R.add(lv3, lv5)
+ lv7 = x[1]
+ lv8 = R.add(lv6, lv7)
+ gv = R.sum(lv8)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"),
z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tensor((),
"float32"):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u)
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = x[0]
+ lv3: R.Tensor((3, 3), "float32") = lv2[0]
+ lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv1[0]
+ lv5: R.Tensor((3, 3), "float32") = lv4[1]
+ lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5)
+ lv7: R.Tensor((3, 3), "float32") = x[1]
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv8, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"),
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3),
"float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) ->
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3),
"float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")),
R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32"))):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u)
+ lv2: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = x[0]
+ lv3: R.Tensor((3, 3), "float32") = lv2[0]
+ lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv1[0]
+ lv5: R.Tensor((3, 3), "float32") = lv4[1]
+ lv6: R.Tensor((3, 3), "float32") = R.add(lv3, lv5)
+ lv7: R.Tensor((3, 3), "float32") = x[1]
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv6, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv8, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv8_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv7_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+ lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+ lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv, lv5_adjoint)
+ lv3_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+ lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv2_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv3_adjoint, lv11)
+ lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv1_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"),
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv4_adjoint, lv21)
+ x_adjoint: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"),
R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = (lv2_adjoint,
lv7_adjoint)
+ lv31: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = lv1_adjoint[0]
+ y_adjoint: R.Tensor((3, 3), "float32") = lv31[0]
+ z_adjoint: R.Tensor((3, 3), "float32") = lv31[1]
+ u_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint[1]
+ R.output(gv, x_adjoint, y_adjoint, z_adjoint, u_adjoint)
+ return (gv, (x_adjoint, y_adjoint, z_adjoint, u_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
+
+
+def test_tuple_update():
+ """One tensor `x` is used in and out of tuple many times."""
+
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")):
+ with R.dataflow():
+ lv0 = (x, y)
+ lv1 = R.add(x, y)
+ lv2 = lv0[0]
+ lv3 = R.add(lv2, y)
+ lv4 = R.add(lv1, lv3)
+ lv5 = (x, y)
+ lv6 = lv5[0]
+ lv7 = lv0[0]
+ lv8 = R.add(lv4, lv6)
+ lv9 = R.add(lv8, lv7)
+ gv = R.sum(lv9)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = lv0[0]
+ lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y)
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3)
+ lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv6: R.Tensor((3, 3), "float32") = lv5[0]
+ lv7: R.Tensor((3, 3), "float32") = lv0[0]
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6)
+ lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv9, axis=None,
keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3),
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3),
"float32"), R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, y)
+ lv2: R.Tensor((3, 3), "float32") = lv0[0]
+ lv3: R.Tensor((3, 3), "float32") = R.add(lv2, y)
+ lv4: R.Tensor((3, 3), "float32") = R.add(lv1, lv3)
+ lv5: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3),
"float32")) = (x, y)
+ lv6: R.Tensor((3, 3), "float32") = lv5[0]
+ lv7: R.Tensor((3, 3), "float32") = lv0[0]
+ lv8: R.Tensor((3, 3), "float32") = R.add(lv4, lv6)
+ lv9: R.Tensor((3, 3), "float32") = R.add(lv8, lv7)
+ gv: R.Tensor((), "float32") = R.sum(lv9, axis=None,
keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv9_adjoint: R.Tensor((3, 3), "float32") =
R.broadcast_to(gv_adjoint, (3, 3))
+ lv8_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint
+ lv7_adjoint: R.Tensor((3, 3), "float32") = lv9_adjoint
+ lv6_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv5_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv6_adjoint, lv)
+ lv4_adjoint: R.Tensor((3, 3), "float32") = lv8_adjoint
+ lv3_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+ lv2_adjoint: R.Tensor((3, 3), "float32") = lv3_adjoint
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv4_adjoint
+ lv11: R.Tensor((3, 3), "float32") = R.add(lv7_adjoint,
lv2_adjoint)
+ lv21: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv0_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,
3), "float32")) = (lv11, lv21)
+ lv31: R.Tensor((3, 3), "float32") = lv5_adjoint[0]
+ lv41: R.Tensor((3, 3), "float32") = R.add(lv31, lv1_adjoint)
+ lv51: R.Tensor((3, 3), "float32") = lv0_adjoint[0]
+ x_adjoint: R.Tensor((3, 3), "float32") = R.add(lv41, lv51)
+ lv61: R.Tensor((3, 3), "float32") = lv5_adjoint[1]
+ lv71: R.Tensor((3, 3), "float32") = R.add(lv61, lv3_adjoint)
+ lv81: R.Tensor((3, 3), "float32") = R.add(lv71, lv1_adjoint)
+ lv91: R.Tensor((3, 3), "float32") = lv0_adjoint[1]
+ y_adjoint: R.Tensor((3, 3), "float32") = R.add(lv81, lv91)
+ R.output(gv, x_adjoint, y_adjoint)
+ return (gv, (x_adjoint, y_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
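+
+# Hand-derived reference for test_tuple_update (a hedged, illustrative sketch;
+# _tuple_update_reference_grads is not part of the upstream test suite).
+# Flattening the tuple accesses gives gv = R.sum(4*x + 2*y), so the adjoints
+# accumulate to the constants below.
+def _tuple_update_reference_grads(x, y):
+ return 4 * np.ones_like(x), 2 * np.ones_like(y)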
+
+
+def test_tuple_op_simple():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((6,), "float32")):
+ with R.dataflow():
+ lv1 = R.split(x, 2)
+ lv2 = R.concat(lv1)
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((6,), "float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0)
+ lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((6,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((6,), "float32"))):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0)
+ lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+ gv: R.Tensor((), "float32") = R.sum(lv2, axis=None, keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv2_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,))
+ lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+ x_adjoint: R.Tensor((6,), "float32") = R.concat(lv1_adjoint, axis=0)
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
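+
+# The adjoints above rely on split and concat being adjoint to each other:
+# the gradient of a concat is a split at the matching indices, and vice
+# versa. Hedged NumPy analogue (illustrative only): since
+# np.concatenate(np.split(g, 2)) recovers g, x_adjoint here is all ones.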
+
+
+def test_tuple_op_construct():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))):
+ with R.dataflow():
+ lv1 = (x, x)
+ lv2 = R.concat(lv1)
+ lv3 = R.concat((x, x))
+ lv4 = R.concat(y)
+ lv5 = R.add(lv2, lv3)
+ lv6 = R.add(lv5, lv4)
+ gv = R.sum(lv6)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x)
+ lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+ lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+ lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0)
+ lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3)
+ lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4)
+ gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")))):
+ with R.dataflow():
+ lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x)
+ lv2: R.Tensor((6,), "float32") = R.concat(lv1, axis=0)
+ lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+ lv4: R.Tensor((6,), "float32") = R.concat(y, axis=0)
+ lv5: R.Tensor((6,), "float32") = R.add(lv2, lv3)
+ lv6: R.Tensor((6,), "float32") = R.add(lv5, lv4)
+ gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv6_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,))
+ lv5_adjoint: R.Tensor((6,), "float32") = lv6_adjoint
+ lv4_adjoint: R.Tensor((6,), "float32") = lv6_adjoint
+ lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+ lv2_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+ lv1_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+ lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0)
+ lv11: R.Tensor((3,), "float32") = lv[0]
+ lv21: R.Tensor((3,), "float32") = lv[1]
+ lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21)
+ lv41: R.Tensor((3,), "float32") = lv1_adjoint[0]
+ lv51: R.Tensor((3,), "float32") = R.add(lv31, lv41)
+ lv61: R.Tensor((3,), "float32") = lv1_adjoint[1]
+ x_adjoint: R.Tensor((3,), "float32") = R.add(lv51, lv61)
+ y_adjoint: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv4_adjoint, indices_or_sections=[3], axis=0)
+ R.output(gv, x_adjoint, y_adjoint)
+ return (gv, (x_adjoint, y_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
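+
+# Adjoint contributions accumulate once per use: x appears four times (twice
+# via lv1 = (x, x) and twice via the inline (x, x)), while each component of
+# y is used once. Hedged hand-derived reference (not part of the upstream
+# test suite):
+def _tuple_construct_reference_grads(x, y0, y1):
+ return 4 * np.ones_like(x), (np.ones_like(y0), np.ones_like(y1))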
+
+
+def test_tuple_op_const():
+ c1 = R.const(np.zeros(3).astype(np.float32))
+ c2 = R.const(np.zeros(3).astype(np.float32))
+ c3 = R.const(np.zeros(3).astype(np.float32))
+
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3,), "float32")):
+ with R.dataflow():
+ lv1 = R.concat((c1, c2))
+ lv2 = R.concat((c3, x))
+ lv3 = R.concat((x, x))
+ lv4 = R.add(lv1, lv2)
+ lv5 = R.add(lv4, lv3)
+ gv = R.sum(lv5)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3,), "float32")) -> R.Tensor((), "float32"):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0)
+ lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0)
+ lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+ lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2)
+ lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"))):
+ # block 0
+ with R.dataflow():
+ lv1: R.Tensor((6,), "float32") = R.concat((c1, c2), axis=0)
+ lv2: R.Tensor((6,), "float32") = R.concat((c3, x), axis=0)
+ lv3: R.Tensor((6,), "float32") = R.concat((x, x), axis=0)
+ lv4: R.Tensor((6,), "float32") = R.add(lv1, lv2)
+ lv5: R.Tensor((6,), "float32") = R.add(lv4, lv3)
+ gv: R.Tensor((), "float32") = R.sum(lv5, axis=None, keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv5_adjoint: R.Tensor((6,), "float32") = R.broadcast_to(gv_adjoint, (6,))
+ lv4_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+ lv3_adjoint: R.Tensor((6,), "float32") = lv5_adjoint
+ lv2_adjoint: R.Tensor((6,), "float32") = lv4_adjoint
+ lv1_adjoint: R.Tensor((6,), "float32") = lv4_adjoint
+ lv: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv3_adjoint, indices_or_sections=[3], axis=0)
+ lv11: R.Tensor((3,), "float32") = lv[0]
+ lv21: R.Tensor((3,), "float32") = lv[1]
+ lv31: R.Tensor((3,), "float32") = R.add(lv11, lv21)
+ lv41: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(lv2_adjoint, indices_or_sections=[3], axis=0)
+ lv51: R.Tensor((3,), "float32") = lv41[1]
+ x_adjoint: R.Tensor((3,), "float32") = R.add(lv31, lv51)
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After["main_adjoint"], Expected["main_adjoint"])
+
+
+def test_const():
+ """const could be used in variable assignment, call argument, and as a
part of tuple"""
+ cst = relax.const(np.ones((3, 3)), "float32")
+
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ lv1 = R.add(x, cst)
+ lv2 = cst
+ lv3 = (cst, (cst, lv1))
+ lv4 = lv3[1]
+ lv5 = lv4[1]
+ lv6 = R.subtract(lv5, lv2)
+ gv = R.sum(lv6)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, cst)
+ lv2: R.Tensor((3, 3), "float32") = cst
+ lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1))
+ lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3[1]
+ lv5: R.Tensor((3, 3), "float32") = lv4[1]
+ lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2)
+ gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))):
+ with R.dataflow():
+ lv1: R.Tensor((3, 3), "float32") = R.add(x, cst)
+ lv2: R.Tensor((3, 3), "float32") = cst
+ lv3: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (cst, (cst, lv1))
+ lv4: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3[1]
+ lv5: R.Tensor((3, 3), "float32") = lv4[1]
+ lv6: R.Tensor((3, 3), "float32") = R.subtract(lv5, lv2)
+ gv: R.Tensor((), "float32") = R.sum(lv6, axis=None, keepdims=False)
+ gv_adjoint: R.Tensor((), "float32") = R.ones((), "float32")
+ lv6_adjoint: R.Tensor((3, 3), "float32") = R.broadcast_to(gv_adjoint, (3, 3))
+ lv5_adjoint: R.Tensor((3, 3), "float32") = lv6_adjoint
+ lv: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv4_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (lv, lv5_adjoint)
+ lv11: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ lv3_adjoint: R.Tuple(R.Tensor((3, 3), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))) = (lv11, lv4_adjoint)
+ lv2_adjoint: R.Tensor((3, 3), "float32") = R.negative(lv6_adjoint)
+ lv21: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = lv3_adjoint[1]
+ lv1_adjoint: R.Tensor((3, 3), "float32") = lv21[1]
+ x_adjoint: R.Tensor((3, 3), "float32") = lv1_adjoint
+ y_adjoint: R.Tensor((3, 3), "float32") = R.zeros((3, 3), "float32")
+ R.output(gv, x_adjoint, y_adjoint)
+ return (gv, (x_adjoint, y_adjoint))
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
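+
+# A hedged reading of the Expected module above: constants are treated as
+# non-differentiable leaves (no adjoint is propagated into cst), and the
+# unused parameter y still receives an explicit zeros adjoint so that the
+# returned gradient tuple covers every input.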
+
+
+def test_params_copy():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x0: R.Tensor((3, 3), "float32"),
+ x1: R.Tensor((3, 3), "float32"),
+ x2: R.Tensor((3, 3), "float32"),
+ x3: R.Tensor((3, 3), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.add(x0, x1)
+ lv1 = R.add(x2, x3)
+ lv2 = R.add(lv0, lv1)
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ After = relax.transform.Gradient("main")(Before)
+ assert len(Before["main"].params) == len(After["main"].params)
+ assert len(Before["main"].params) == len(After["main_adjoint"].params)
+ for i in range(len(After["main"].params)):
+ assert Before["main"].params[i] == After["main"].params[i]
+ assert Before["main"].params[i] != After["main_adjoint"].params[i]
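+
+# A hedged note on the inequality asserts above: main_adjoint receives fresh
+# copies of the parameter Vars rather than the originals, presumably so the
+# two functions never share variable bindings.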
+
+
+def test_function_copy():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x0: R.Tensor((3, 3), "float32"),
+ x1: R.Tensor((3, 3), "float32"),
+ x2: R.Tensor((3, 3), "float32"),
+ x3: R.Tensor((3, 3), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.add(x0, x1)
+ lv1 = R.add(x2, x3)
+ lv2 = R.add(lv0, lv1)
+ gv = R.sum(lv2)
+ R.output(gv)
+ return gv
+
+ After = relax.transform.Gradient("main")(Before)
+
+ # After should have the same "main" function as Before
+ assert_structural_equal(Before["main"], After["main"])
+
+ # The first bindings of After["main_adjoint"] should match those of Before["main"]
+ old_bindings = Before["main"].body.blocks[0].bindings
+ old_bindings_len = len(old_bindings)
+ new_bindings = After["main_adjoint"].body.blocks[0].bindings[:old_bindings_len]
+ assert_structural_equal(old_bindings, new_bindings, True)
+
+
+def test_report_error():
+ @I.ir_module
+ class TargetNotTensor:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ lv1 = R.sum(x)
+ gv = R.tuple(lv1, lv1)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(TargetNotTensor)
+
+ @I.ir_module
+ class TargetNotScalar:
+ @R.function
+ def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.add(x0, x1)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(TargetNotScalar)
+
+ @I.ir_module
+ class TargetNotFloat:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.const(1)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(TargetNotFloat)
+
+ @I.ir_module
+ class ReturnScalarAndWrongTargetIndex:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.sum(x)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main", target_index=1)(ReturnScalarAndWrongTargetIndex)
+
+ @I.ir_module
+ class ReturnTupleAndWrongTargetIndex:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv1 = R.sum(x)
+ gv2 = R.sum(y)
+ R.output(gv1, gv2)
+ return gv1, gv2
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main", target_index=2)(ReturnTupleAndWrongTargetIndex)
+
+ @I.ir_module
+ class IndexedTargetNotVar:
+ @R.function
+ def main(x: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.sum(x)
+ R.output(gv)
+ return gv, (gv, gv)
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main", target_index=1)(IndexedTargetNotVar)
+
+ @I.ir_module
+ class NoDataflow:
+ @R.function
+ def main(x0: R.Tensor((3, 3), "float32")):
+ gv = R.sum(x0)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(NoDataflow)
+
+ @I.ir_module
+ class MultiBlocks:
+ @R.function
+ def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), "float32")):
+ # block 0
+ with R.dataflow():
+ gv = R.add(x0, x1)
+ R.output(gv)
+ # block 1
+ gv1 = R.sum(x0)
+ return gv1
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(MultiBlocks)
+
+ @I.ir_module
+ class NormalModule:
+ @R.function
+ def main(x0: R.Tensor((3, 3), "float32"), x1: R.Tensor((3, 3), "float32")):
+ with R.dataflow():
+ gv = R.sum(x0)
+ R.output(gv)
+ return gv
+
+ @T.prim_func
+ def sum(
+ rxplaceholder: T.Buffer((T.int64(3), T.int64(3)), "float32"),
+ rxplaceholder_red: T.Buffer((), "float32"),
+ ):
+ T.func_attr({"tir.noalias": True})
+ for k0, k1 in T.grid(T.int64(3), T.int64(3)):
+ with T.block("rxplaceholder_red"):
+ v_k0, v_k1 = T.axis.remap("RR", [k0, k1])
+ T.reads(rxplaceholder[v_k0, v_k1])
+ T.writes(rxplaceholder_red[()])
+ with T.init():
+ rxplaceholder_red[()] = T.float32(0)
+ rxplaceholder_red[()] = rxplaceholder_red[()] + rxplaceholder[v_k0, v_k1]
+
+ # no such function
+ with pytest.raises(ValueError):
+ relax.transform.Gradient("main1")(NormalModule)
+ # wrong function type
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("sum")(NormalModule)
+ # require_grads var is not a parameter of the target function
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main", require_grads=MultiBlocks["main"].params[0])(NormalModule)
+
+ @I.ir_module
+ class IntDtype:
+ @R.function
+ def main(x: R.Tensor((3, 3), "int64")):
+ with R.dataflow():
+ lv1 = R.add(x, x)
+ gv = R.sum(lv1)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(IntDtype)
+
+ @I.ir_module
+ class IntDtypeTuple:
+ @R.function
+ def main(x: R.Tuple(R.Tensor((3, 3), "int64"), R.Tensor((3, 3),
"int64"))):
+ with R.dataflow():
+ lv1 = x[0]
+ lv2 = x[1]
+ lv3 = R.add(lv1, lv2)
+ gv = R.sum(lv3)
+ R.output(gv)
+ return gv
+
+ with pytest.raises(TVMError):
+ relax.transform.Gradient("main")(IntDtypeTuple)
+
+
+def test_shape_expr():
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((3, 4), "float32")):
+ with R.dataflow():
+ s = R.shape([3, 2, 2])
+ lv = R.reshape(x, s)
+ gv = R.sum(lv)
+ R.output(gv)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), dtype="float32"))):
+ with R.dataflow():
+ s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2])
+ lv = R.reshape(x, s)
+ gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False)
+ gv_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32")
+ lv_adjoint: R.Tensor([3, 2, 2], "float32") = R.broadcast_to(gv_adjoint, R.shape([3, 2, 2]))
+ x_adjoint: R.Tensor((3, 4), dtype="float32") = R.reshape(lv_adjoint, R.shape([3, 4]))
+ R.output(gv, x_adjoint)
+ return (gv, (x_adjoint,))
+
+ @R.function
+ def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((), dtype="float32"):
+ with R.dataflow():
+ s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2])
+ lv = R.reshape(x, s)
+ gv: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False)
+ R.output(gv)
+ return gv
+ # fmt: on
+
+ After = relax.transform.Gradient("main")(Before)
+ assert_structural_equal(After, Expected)
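+
+# A hedged reading of the Expected module above: since reshape only
+# rearranges elements, its adjoint is a reshape of the incoming adjoint back
+# to the input shape, i.e. the gradient of R.sum(R.reshape(x, (3, 2, 2)))
+# w.r.t. x is all ones with shape (3, 4).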
+
+
+def test_mlp_script():
+ """
+ An example of single layer multi-layer perceptron. You can add extra
layers if you want.
+
+ For n-layer perceptron, see test_transform_gradient_numeric.py.
+ """
+ # fmt: off
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((3, 10), "float32"),
+ w0: R.Tensor((10, 5), "float32"),
+ b0: R.Tensor((5,), "float32"),
+ label: R.Tensor((3, 5), "float32"),
+ ):
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ out = R.add(lv0, b0)
+ logits = R.nn.log_softmax(out)
+ loss = R.nn.cross_entropy_with_logits(logits, label)
+ R.output(loss)
+ return loss
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_adjoint(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((5,), dtype="float32"))):
+ with R.dataflow():
+ lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, out_dtype="void")
+ out: R.Tensor((3, 5), dtype="float32") = R.add(lv0, b0)
+ logits: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(out, axis=-1)
+ loss: R.Tensor((), dtype="float32") = R.nn.cross_entropy_with_logits(logits, label)
+ loss_adjoint: R.Tensor((), dtype="float32") = R.ones(R.shape([]), dtype="float32")
+ lv: R.Tensor((), dtype="float32") = R.divide(loss_adjoint, R.const(3, "float32"))
+ lv1: R.Tensor((), dtype="float32") = R.negative(lv)
+ logits_adjoint: R.Tensor((3, 5), dtype="float32") = R.multiply(lv1, label)
+ lv2: R.Tensor((3, 1), dtype="float32") = R.sum(logits_adjoint, axis=[-1], keepdims=True)
+ lv3: R.Tensor((3, 5), dtype="float32") = R.exp(logits)
+ lv4: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv3)
+ out_adjoint: R.Tensor((3, 5), dtype="float32") = R.subtract(logits_adjoint, lv4)
+ lv0_adjoint: R.Tensor((3, 5), dtype="float32") = out_adjoint
+ lv5: R.Tensor((10, 3), dtype="float32") = R.permute_dims(x, axes=[1, 0])
+ lv6: R.Tensor((10, 5), dtype="float32") = R.matmul(lv5, lv0_adjoint, out_dtype="void")
+ w0_adjoint: R.Tensor((10, 5), dtype="float32") = R.collapse_sum_to(lv6, R.shape([10, 5]))
+ b0_adjoint: R.Tensor((5,), dtype="float32") = R.collapse_sum_to(out_adjoint, R.shape([5]))
+ R.output(loss, w0_adjoint, b0_adjoint)
+ return (loss, (w0_adjoint, b0_adjoint))
+
+ @R.function
+ def main(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), dtype="float32")) -> R.Tensor((), dtype="float32"):
+ with R.dataflow():
+ lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, out_dtype="void")
+ out: R.Tensor((3, 5), dtype="float32") = R.add(lv0, b0)
+ logits: R.Tensor((3, 5), dtype="float32") = R.nn.log_softmax(out, axis=-1)
+ loss: R.Tensor((), dtype="float32") = R.nn.cross_entropy_with_logits(logits, label)
+ R.output(loss)
+ return loss
+ # fmt: on
+
+ After = relax.transform.Gradient("main", require_grads=Before["main"].params[1:3])(Before)
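+ # require_grads=Before["main"].params[1:3] restricts differentiation to w0
+ # and b0, so main_adjoint returns adjoints only for those two parameters.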
+ assert_structural_equal(After, Expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_gradient_numeric.py b/tests/python/relax/test_transform_gradient_numeric.py
new file mode 100644
index 0000000000..7585ecf1f6
--- /dev/null
+++ b/tests/python/relax/test_transform_gradient_numeric.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import numpy as np
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.relay.testing import rand
+from tvm.testing import assert_allclose
+from tvm.testing.utils import check_numerical_grads
+from tvm.script.parser import ir as I, relax as R
+from tvm.relax.transform import LegalizeOps
+
+
+def _legalize_and_build(mod, target, dev):
+ lowered_mod = LegalizeOps()(mod)
+ ex = relax.build(lowered_mod, target)
+ vm = relax.VirtualMachine(ex, dev)
+ return vm
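+
+
+def _run_adjoint(vm, func_name, *args):
+ """Hedged convenience sketch (not used by the tests below): call the
+ generated "<func_name>_adjoint" entry and return (output, gradients)."""
+ return vm[func_name + "_adjoint"](*args)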
+
+
[email protected]_targets("llvm")
+def test_manual_gradient(target, dev):
+ # The expression computed is sum((2x - 2y) * (y + z))
+ # the gradient of x is broadcast_to(2y + 2z, x.shape)
+ # the gradient of y is collapse_sum_to((2x - 4y - 2z), y.shape)
+ # the gradient of z is collapse_sum_to((2x - 2y), z.shape)
+ # the gradient of u is 0
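+ # Derivation (product rule; x is (3, 5), y/z/u are (5,) and broadcast):
+ # d/dx[(2x - 2y)(y + z)] = 2(y + z), broadcast to x.shape
+ # d/dy[(2x - 2y)(y + z)] = -2(y + z) + (2x - 2y) = 2x - 4y - 2z,
+ # collapse-summed over the broadcast axis
+ # d/dz[(2x - 2y)(y + z)] = 2x - 2y, collapse-summed likewise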
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((3, 5), "float32"),
+ y: R.Tensor((5,), "float32"),
+ z: R.Tensor((5,), "float32"),
+ u: R.Tensor((5,), "float32"),
+ ):
+ with R.dataflow():
+ lv1 = R.add(x, x)
+ lv2 = R.subtract(lv1, y)
+ lv3 = R.subtract(lv2, y)
+ lv4 = R.add(y, z)
+ lv5 = R.multiply(lv3, lv4)
+ lv6 = R.sum(lv5)
+ R.output(lv6)
+ return lv6
+
+ After = relax.transform.Gradient("main")(Before)
+
+ args = [rand("float32", 3, 5), rand("float32", 5), rand("float32", 5), rand("float32", 5)]
+ args_np = [x.numpy() for x in args]
+
+ vm = _legalize_and_build(After, target, dev)
+ output, grads = vm["main_adjoint"](*args)
+ output_np = np.sum((2 * args_np[0] - 2 * args_np[1]) * (args_np[1] + args_np[2]))
+ assert_allclose(output.numpy(), output_np, atol=1e-4)
+
+ expected_grads_nd = [
+ (2 * args_np[1] + 2 * args_np[2]) * np.ones_like(args_np[0]),
+ np.sum((2 * args_np[0] - 4 * args_np[1] - 2 * args_np[2]), axis=0),
+ np.sum((2 * args_np[0] - 2 * args_np[1]), axis=0),
+ np.zeros_like(args_np[3]),
+ ]
+ for i, j in zip(grads, expected_grads_nd):
+ assert_allclose(i.numpy(), j, atol=1e-4)
+
+
[email protected]_targets("llvm")
+def test_mlp_blockbuilder(target, dev):
+ layers, in_size, out_size, hidden_size, batch_size = 3, 5, 5, 5, 4
+
+ input_list = [relax.Var("x", R.Tensor((batch_size, in_size), "float32"))]
+ w_list = (
+ [relax.Var("w_0", R.Tensor((in_size, hidden_size), "float32"))]
+ + [
+ relax.Var("w_" + str(i + 1), R.Tensor((hidden_size, hidden_size),
"float32"))
+ for i in range(layers - 2)
+ ]
+ + [relax.Var("w_" + str(layers - 1), R.Tensor((hidden_size, out_size),
"float32"))]
+ )
+ b_list = [
+ relax.Var("b_" + str(i), R.Tensor((hidden_size,), "float32")) for i in
range(layers - 1)
+ ] + [relax.Var("b_" + str(layers - 1), R.Tensor((out_size,), "float32"))]
+ label_list = [relax.Var("y", R.Tensor((batch_size,), "int64"))]
+ args_list = input_list + w_list + b_list + label_list
+
+ bb = relax.BlockBuilder()
+ with bb.function("MLP", args_list):
+ with bb.dataflow():
+ current = input_list[0]
+ for i in range(layers):
+ lv0 = bb.emit(R.matmul(current, w_list[i]))
+ lv1 = bb.emit(R.add(lv0, b_list[i]))
+ current = bb.emit(R.nn.relu(lv1) if i < layers - 1 else lv1)
+ logits = R.nn.log_softmax(current)
+ loss = bb.emit(R.nn.nll_loss(logits, label_list[0]))
+ gv0 = bb.emit_output(loss)
+ bb.emit_func_output(gv0)
+
+ Before = bb.get()
+ After = relax.transform.Gradient("MLP", w_list + b_list)(Before)
+ # Check that the computed gradients match numerical estimates
+ args = []
+ for arg in After["MLP_adjoint"].params:
+ shape = [int(l) for l in arg.struct_info.shape]
+ if arg.struct_info.dtype == "int64":
+ args.append(tvm.nd.array(np.random.randint(0, out_size, size=shape).astype(np.int64)))
+ else: # float32
+ args.append(rand("float32", *shape))
+
+ vm_before = _legalize_and_build(Before, target, dev)
+ vm_after = _legalize_and_build(After, target, dev)
+ _, grad = vm_after["MLP_adjoint"](*args)
+
+ def func(*inputs):
+ loss = vm_before["MLP"](args[0], *[tvm.nd.array(i) for i in inputs], args[-1])
+ return loss.numpy()
+
+ check_numerical_grads(func, [i.numpy() for i in args[1:-1]], [i.numpy() for i in grad])
+
+
[email protected]_targets("llvm")
+def test_complex(target, dev):
+ cst = relax.const(np.ones((6,)), dtype="float32")
+ cst1 = relax.const(np.array(3), dtype="int64")
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor((6,), "float32"), y: R.Tensor((6, 3, 4), "float32")):
+ with R.dataflow():
+ lv1 = R.split(x, 2)
+ lv2 = lv1[0]
+ lv3 = lv1[1]
+ lv4 = lv2 + lv3
+ lv5 = (lv4, lv3)
+ lv6 = R.concat(lv5)
+ lv7 = (x, x)
+ lv8 = R.concat(lv7)
+ lv9 = R.concat(lv7)
+ lv10 = R.add(lv8, lv9)
+ lv11 = R.split(lv10, 2)
+ lv12 = R.add(lv6, lv11[0])
+ lv13 = cst
+ lv14 = R.add(lv12, lv13)
+ lv15 = R.subtract(lv13, lv14)
+ lv16 = R.multiply(lv14, lv15)
+ lv17 = R.multiply(lv15, lv16)
+ lv18 = R.tanh(lv17)
+ lv19 = R.sigmoid(lv18)
+ lv20 = R.permute_dims(y, axes=[0, 2, 1])
+ lv21 = R.sigmoid(lv20)
+ lv22 = R.matmul(y, lv21)
+ lv23 = R.sum(lv22, axis=[1, 2])
+ lv24 = R.add(lv19, lv23)
+ lv25 = R.nn.log_softmax(lv24)
+ gv = R.nn.nll_loss(lv25, cst1)
+ R.output(gv)
+ return gv
+
+ After = relax.transform.Gradient("main")(Before)
+ args = []
+ for arg in After["main_adjoint"].params:
+ shape = [int(l) for l in arg.struct_info.shape]
+ args.append(rand("float32", *shape))
+
+ vm_before = _legalize_and_build(Before, target, dev)
+ vm_after = _legalize_and_build(After, target, dev)
+ _, grad = vm_after["main_adjoint"](*args)
+
+ def func(*inputs):
+ loss = vm_before["main"](*[tvm.nd.array(i) for i in inputs])
+ return loss.numpy()
+
+ check_numerical_grads(func, [i.numpy() for i in args], [i.numpy() for i in grad])
+
+
+if __name__ == "__main__":
+ tvm.testing.main()