This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 61249b41ce [Relax][Transform] Provide callback versions of
LazyTransformParams (#16798)
61249b41ce is described below
commit 61249b41ce0f40ba50c582901c5932907708da89
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 11:25:59 2024 -0500
[Relax][Transform] Provide callback versions of LazyTransformParams (#16798)
* [TIR][Analysis] Implemented tir.analysis.is_pure_function
This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects for a `PrimExpr`, `is_pure_function` checks for side
effects for the function as a whole.
* [Transform] Implement relax.transform.ComputePrimValue
Prior to this commit, while expressions of type `DataType::Int(64)`
could be computed in the `relax.transform.VMShapeLower`, expressions
of any other type could not. This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.
This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.
* [Relax] Allow R.Prim('bool') in relax::If and assert_op
Prior to this commit, the condition used for `relax::If` node and the
`"relax.assert_op"` operator was required to be a scalar tensor. This
made it difficult to alter behavior based on a runtime shape
parameter. For example, delegating to a vectorized implementation
based on whether a tensor shape is divisible by the vector size.
This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
* [Relax][Transform] Provide callback versions of LazyTransformParams
Prior to this commit, the `LazyTransformParams` function could be used
to load model parameters on demand. However, the function used to
load or set parameters needed to be registered within the global
registry of `PackedFunc`s. This PR provides `LazyGetInput` and
`LazySetOutput` transforms, which perform the lazy-loading through a
`R.Callable` callback argument, rather than through a
globally-registered `PackedFunc`.
* Reverse the order of parameters in fget_param
If `fget_param` accepts the parameter index first, and the parameter
name second, then an implementation with signature and default values
of `def fget_param(index: int, name: Optional[str]=None)` could be
used as either the callback of `LazyGetInput`, or as the
globally-registered `"get_item"` for the existing
`LazyTransformParams`, which should make it easier to transition
between the two.
* lint fix
* Updates based on review comments
---
python/tvm/relax/transform/__init__.py | 2 +
python/tvm/relax/transform/transform.py | 80 +++++
src/relax/transform/lazy_transform_params.cc | 266 +++++++++++++++++
.../relax/test_transform_lazy_transform_params.py | 328 +++++++++++++++++++++
4 files changed, 676 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 11e301c26c..5e76fff6bd 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -50,6 +50,8 @@ from .transform import (
InlinePrivateFunctions,
KillAfterLastUse,
LambdaLift,
+ LazyGetInput,
+ LazySetOutput,
LegalizeOps,
LiftTransformParams,
LowerAllocTensor,
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index dbc35d48d3..fa18cc672b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -303,6 +303,86 @@ def LambdaLift() -> tvm.ir.transform.Pass:
return _ffi_api.LambdaLift()
+def LazyGetInput() -> tvm.ir.transform.Pass:
+ """A pass that requests inputs lazily.
+
+ In many cases, the size of the model weights exceeds the available
+ memory on a GPU. In these cases, a function that accepts all
+ model weights as arguments would not be able to be called. In
+ these cases, parameters must be loaded as they are required by the
+ function, and unloaded once they are no longer needed.
+
+ This pass mutates a function such that all model weights
+ (arguments after the first `func.attrs["num_input"]` arguments)
+ are loaded on demand. Rather than accepting the weights as
+ function arguments, the function accepts a callback argument,
+ which can load each parameter as needed. The callback accepts two
+ arguments, first the index of the model weight, and second the
+ name of the parameter. The callback should return the parameter
+ as specified.
+
+ .. code-block:: python
+
+ @R.function
+ def before(A: R.Tensor([16,32],"float32")):
+ ...
+
+ @R.function
+ def after(fget_param: R.Callable([R.Prim('int64'), R.Object],
R.Object)):
+ A_untyped = fget_param(0, R.str('A'))
+ A = R.match_cast(A_untyped, R.Tensor([16,32], "float32"))
+ ...
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+
+ """
+ return _ffi_api.LazyGetInput()
+
+
+def LazySetOutput() -> tvm.ir.transform.Pass:
+ """A pass that sets function outputs when available
+
+ In many cases, the size of the model weights exceeds the available
+ memory on a GPU. In these cases, a function that produces all
+ model weights as a single return value would not be able to be
+ called. In these cases, parameters must be returned as they are
+ produced, unloaded from the GPU (or saved to disk), before
+ producing additional outputs.
+
+ This pass mutates a function such that all outputs from a function
+ are returned when they are available. The function accepts an
+ additional callback argument, which is called with each output of
+ the function. The callback accepts two arguments, first the index
+ of the output tuple that was produced (or zero if the output is
+ not a tuple), and second the value itself.
+
+ .. code-block:: python
+
+ @R.function
+ def before(args):
+ ...
+ return (A, B)
+
+ @R.function
+ def after(args, fset_param: R.Callable([R.Prim('int64'), R.Object])):
+ ...
+ fset_param(0, A)
+ ...
+ fset_param(1, B)
+ ...
+ return ()
+
+
+ Returns
+ -------
+ ret : tvm.ir.transform.Pass
+
+ """
+ return _ffi_api.LazySetOutput()
+
+
def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass:
"""A pass that converts consecutive dataflow operations
inside binding blocks into dataflow blocks.
diff --git a/src/relax/transform/lazy_transform_params.cc
b/src/relax/transform/lazy_transform_params.cc
new file mode 100644
index 0000000000..21608af7db
--- /dev/null
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*! \file src/relax/transform/lazy_transform_params.cc */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_map>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+std::optional<int64_t> GetNumInputParams(const FunctionNode* func) {
+ if (auto opt_int_imm = func->GetAttr<IntImm>(attr::kNumInput)) {
+ int64_t num_input_params = opt_int_imm.value()->value;
+ CHECK_GE(num_input_params, 0) << "ValueError: "
+ << "Annotation for attr::kNumInput (\"" <<
attr::kNumInput
+ << "\") must be non-negative, but was " <<
num_input_params;
+ CHECK_LE(static_cast<size_t>(num_input_params), func->params.size())
+ << "ValueError: "
+ << "Annotation for attr::kNumInput (\"" << attr::kNumInput << "\")
specifies "
+ << num_input_params << " parameters to be provided at runtime, "
+ << "but the function only accepts " << func->params.size() << "
parameters in total";
+ return num_input_params;
+ } else {
+ return std::nullopt;
+ }
+}
+
+class LazyInputMutator : public ExprMutator {
+ public:
+ Expr VisitExpr_(const FunctionNode* func) override {
+ if (plan_.has_value()) {
+ return ExprMutator::VisitExpr_(func);
+ }
+
+ int64_t num_input_params = GetNumInputParams(func).value_or(0);
+
+ std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual>
param_lookup;
+ for (size_t i = num_input_params; i < func->params.size(); i++) {
+ param_lookup.insert({func->params[i], i - num_input_params});
+ }
+
+ Var fget_param("fget_param",
+ FuncStructInfo({PrimStructInfo(DataType::Int(64)),
ObjectStructInfo()},
+ ObjectStructInfo()));
+
+ Array<Var> new_params(func->params.begin(), func->params.begin() +
num_input_params);
+ new_params.push_back(fget_param);
+
+ auto node = GetRef<Function>(func);
+ node.CopyOnWrite()->params = new_params;
+ node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));
+
+ plan_ = FunctionPlan{std::move(param_lookup), fget_param};
+ auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
+ plan_.reset();
+ return output;
+ }
+
+ Expr VisitExpr_(const VarNode* op) override {
+ if (plan_) {
+ Var var = GetRef<Var>(op);
+ if (auto it = plan_->param_lookup.find(var); it !=
plan_->param_lookup.end()) {
+ auto untyped =
+ builder_->Emit(relax::Call(plan_->fget_param,
+ {
+ PrimValue(IntImm(DataType::Int(64),
it->second)),
+ StringImm(var->name_hint()),
+ }),
+ var->name_hint() + "_untyped");
+ return builder_->EmitMatchCast(untyped, GetStructInfo(var),
var->name_hint());
+ }
+ }
+
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ private:
+ struct FunctionPlan {
+ std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual>
param_lookup;
+ Expr fget_param;
+ };
+ std::optional<FunctionPlan> plan_;
+};
+
+class LazyOutputMutator : public ExprMutator {
+ public:
+ Expr VisitExpr_(const FunctionNode* func) override {
+ if (plan_.has_value()) {
+ return ExprMutator::VisitExpr_(func);
+ }
+
+ std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash,
ObjectPtrEqual> output_lookup;
+ std::vector<std::tuple<size_t, Expr>> inline_outputs;
+ auto define_lookup = [&](size_t output_index, Expr output_value) {
+ if (auto var = output_value.as<Var>()) {
+ output_lookup[var.value()].push_back(output_index);
+ } else {
+ inline_outputs.push_back({output_index, output_value});
+ }
+ };
+
+ auto func_body = Downcast<SeqExpr>(func->body);
+ if (auto tuple_output = func_body->body.as<TupleNode>()) {
+ for (size_t i = 0; i < tuple_output->fields.size(); i++) {
+ define_lookup(i, tuple_output->fields[i]);
+ }
+ } else {
+ define_lookup(0, func_body->body);
+ }
+
+ Var fset_output("fset_output",
+ FuncStructInfo({PrimStructInfo(DataType::Int(64)),
ObjectStructInfo()},
+ TupleStructInfo(Array<StructInfo>{})));
+ plan_ = FunctionPlan{std::move(output_lookup), fset_output};
+
+ std::optional<int64_t> num_input_params = GetNumInputParams(func);
+
+ auto new_params = func->params;
+ new_params.insert(new_params.begin() +
num_input_params.value_or(func->params.size()),
+ fset_output);
+
+ BindingBlock start_of_func = [&]() {
+ Array<Binding> propagated_params;
+ for (auto param : func->params) {
+ GenerateSetOutputCalls(param, [&](const auto& fset_output_call) {
+ Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
+ propagated_params.push_back(VarBinding(void_output,
fset_output_call));
+ });
+ }
+ return BindingBlock(propagated_params);
+ }();
+ BindingBlock end_of_func = [&]() {
+ Array<Binding> propagated_params;
+ for (const auto& [output_index, expr] : inline_outputs) {
+ Call fset_output_call(fset_output,
+ {PrimValue(IntImm(DataType::Int(64),
output_index)), expr});
+ Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
+ propagated_params.push_back(VarBinding(void_output, fset_output_call));
+ }
+ return BindingBlock(propagated_params);
+ }();
+
+ Array<BindingBlock> new_blocks = func_body->blocks;
+ new_blocks.insert(new_blocks.begin(), start_of_func);
+ new_blocks.push_back(end_of_func);
+ Expr new_body = SeqExpr(new_blocks, Tuple(Array<Expr>{}));
+
+ auto node = GetRef<Function>(func);
+ {
+ auto write_ptr = node.CopyOnWrite();
+ write_ptr->params = new_params;
+ write_ptr->body = new_body;
+ }
+ if (num_input_params.has_value()) {
+ node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value()
+ 1));
+ }
+
+ auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
+ plan_.reset();
+ return output;
+ }
+
+ void VisitBinding(const Binding& binding) override {
+ ExprMutator::VisitBinding(binding);
+ GenerateSetOutputCalls(binding->var, [this](const auto& fset_output_call) {
+ builder_->Emit(fset_output_call, "_void");
+ });
+ }
+
+ private:
+ template <typename Callback>
+ void GenerateSetOutputCalls(const Var& var, Callback callback) {
+ if (plan_.has_value()) {
+ if (auto it = plan_->output_lookup.find(var); it !=
plan_->output_lookup.end()) {
+ for (auto output_index : it->second) {
+ callback(
+ Call(plan_->fset_output, {PrimValue(IntImm(DataType::Int(64),
output_index)), var}));
+ }
+ }
+ }
+ }
+
+ struct FunctionPlan {
+ std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash,
ObjectPtrEqual> output_lookup;
+ Expr fset_output;
+ };
+ std::optional<FunctionPlan> plan_;
+};
+} // namespace
+
+Function WithLazyInputs(Function func) {
+ LazyInputMutator mutator;
+
+ func = Downcast<Function>(mutator.VisitExpr(func));
+ func = Downcast<Function>(EliminateCommonSubexpr(func));
+ func = Downcast<Function>(RemoveAllUnused(func));
+ return func;
+}
+
+Function WithLazyOutputs(Function func) {
+ LazyOutputMutator mutator;
+
+ func = Downcast<Function>(mutator.VisitExpr(func));
+ return func;
+}
+
+namespace transform {
+
+Pass LazyGetInput() {
+ auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ return func;
+ }
+ return WithLazyInputs(func);
+ };
+ return CreateFunctionPass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"LazyGetInput",
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput);
+
+Pass LazySetOutput() {
+ auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+ if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+ return func;
+ }
+ return WithLazyOutputs(func);
+ };
+ return CreateFunctionPass(/*pass_function=*/pass_func,
+ /*opt_level=*/0,
+ /*pass_name=*/"LazySetOutput",
+ /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index b16de32ceb..833cbd460c 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -824,5 +824,333 @@ def test_params_without_tuple_with_symbolic_var():
tvm.ir.assert_structural_equal(After, Expected)
+def test_get_item_callback():
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(fget_param: R.Callable([R.Prim("int64"),
R.Object], R.Object)):
+ R.func_attr({"num_input": 1})
+ A = fget_param(R.prim_value(0), R.str("A"))
+ A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+ C = R.multiply(A, R.const(2, "float32"))
+
+ B = fget_param(R.prim_value(1), R.str("B"))
+ B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ After = relax.transform.LazyGetInput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_get_item_callback_num_attrs():
+ @I.ir_module
+ class Before:
+ @R.function(pure=False)
+ def transform_params(
+ rank_arg: R.Prim(value="rank"),
+ world_size_arg: R.Prim(value="world_size"),
+ weight_A: R.Tensor([16, 64], "float32"),
+ weight_B: R.Tensor([1024, 2048], "float32"),
+ ):
+ R.func_attr({"num_input": 2})
+
+ rank = T.int64()
+ world_size = T.int64()
+
+ _ = R.assert_op(
+ R.prim_value(16 % world_size == 0),
+ [R.prim_value(16), R.prim_value(world_size)],
+ format=(
+ "World size must evenly divide A.shape[0] ({}), "
+ "but received world size of {}."
+ ),
+ )
+ weight_A = R.strided_slice(
+ weight_A,
+ axes=[0],
+ begin=[rank * 16 // world_size],
+ end=[(rank + 1) * 16 // world_size],
+ )
+
+ _ = R.assert_op(
+ R.prim_value(2048 % world_size == 0),
+ [R.prim_value(2048), R.prim_value(world_size)],
+ format=(
+ "World size must evenly divide B.shape[1] ({}), "
+ "but received world size of {}."
+ ),
+ )
+ weight_B = R.strided_slice(
+ weight_B,
+ axes=[1],
+ begin=[rank * 2048 // world_size],
+ end=[(rank + 1) * 2048 // world_size],
+ )
+
+ return (weight_A, weight_B)
+
+ @I.ir_module
+ class Expected:
+ @R.function(pure=False)
+ def transform_params(
+ rank_arg: R.Prim(value="rank"),
+ world_size_arg: R.Prim(value="world_size"),
+ fget_item: R.Callable([R.Prim("int64"), R.Object], R.Object),
+ ):
+ R.func_attr({"num_input": 3})
+
+ rank = T.int64()
+ world_size = T.int64()
+
+ _ = R.assert_op(
+ R.prim_value(16 % world_size == 0),
+ [R.prim_value(16), R.prim_value(world_size)],
+ format=(
+ "World size must evenly divide A.shape[0] ({}), "
+ "but received world size of {}."
+ ),
+ )
+ weight_A = fget_item(R.prim_value(0), R.str("weight_A"))
+ weight_A = R.match_cast(weight_A, R.Tensor([16, 64], "float32"))
+ weight_A = R.strided_slice(
+ weight_A,
+ axes=[0],
+ begin=[rank * 16 // world_size],
+ end=[(rank + 1) * 16 // world_size],
+ )
+
+ _ = R.assert_op(
+ R.prim_value(2048 % world_size == 0),
+ [R.prim_value(2048), R.prim_value(world_size)],
+ format=(
+ "World size must evenly divide B.shape[1] ({}), "
+ "but received world size of {}."
+ ),
+ )
+ weight_B = fget_item(R.prim_value(1), R.str("weight_B"))
+ weight_B = R.match_cast(weight_B, R.Tensor([1024, 2048],
"float32"))
+ weight_B = R.strided_slice(
+ weight_B,
+ axes=[1],
+ begin=[rank * 2048 // world_size],
+ end=[(rank + 1) * 2048 // world_size],
+ )
+
+ return (weight_A, weight_B)
+
+ After = relax.transform.LazyGetInput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback():
+ """fset_output is called for each element of the output tuple
+
+ The call is placed immediately after the corresponding
+ `VarBinding`.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, C)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ ):
+ C = R.multiply(A, R.const(2, "float32"))
+ fset_output(R.prim_value(1), C)
+ D = R.add(C, B)
+ fset_output(R.prim_value(0), D)
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_of_param():
+ """fset_output may need to be called for parameters
+
+ A function parameter does not have a `VarBinding`. If a parameter
+ is returned in the output tuple, the `fset_output` call is
+ generated at the beginning of the function.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ ):
+ fset_output(R.prim_value(1), B)
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ fset_output(R.prim_value(0), D)
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_num_input():
+ """The parameter transformation may have other runtime parameters
+
+ The new `fset_output` parameter is placed after the other runtime
+ parameters, before any model weights.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ R.func_attr({"num_input": 1})
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ B: R.Tensor([16, 16], "float32"),
+ ):
+ R.func_attr({"num_input": 2})
+ fset_output(R.prim_value(1), B)
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ fset_output(R.prim_value(0), D)
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_duplicate_output():
+ """fset_output may be called more than once for a variable
+
+ A variable may occur multiple times in the output tuple. The
+ `fset_output` callback should be called once for each tuple
+ element, even if they reuse the same variable.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, D)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ ):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ fset_output(R.prim_value(0), D)
+ fset_output(R.prim_value(1), D)
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_inline_const():
+ """fset_output may be called for inline objects
+
+ The return tuple may contain inline leaf nodes, such as
+ `relax.PrimValue` or `relax.Constant`. A call to `fset_output`
+ must be generated, even though they do not have an associated
+ `relax.VarBinding`.
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (C, D, R.prim_value(42), R.const(17.5, "float16"))
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ ):
+ C = R.multiply(A, R.const(2, "float32"))
+ fset_output(R.prim_value(0), C)
+ D = R.add(C, B)
+ fset_output(R.prim_value(1), D)
+ fset_output(R.prim_value(2), R.prim_value(42))
+ fset_output(R.prim_value(3), R.const(17.5, "float16"))
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_non_tuple_output():
+ """Non-tuple outputs produce a single call to fset_output"""
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(A: R.Tensor([16, 16], "float32"), B:
R.Tensor([16, 16], "float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return D
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ A: R.Tensor([16, 16], "float32"),
+ B: R.Tensor([16, 16], "float32"),
+ fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+ ):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ fset_output(R.prim_value(0), D)
+ return R.tuple()
+
+ After = relax.transform.LazySetOutput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()