This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 61249b41ce [Relax][Transform] Provide callback versions of LazyTransformParams (#16798)
61249b41ce is described below

commit 61249b41ce0f40ba50c582901c5932907708da89
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 11:25:59 2024 -0500

    [Relax][Transform] Provide callback versions of LazyTransformParams (#16798)
    
    * [TIR][Analysis] Implemented tir.analysis.is_pure_function
    
    This commit introduces two related utilities,
    `tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
    In contrast to the existing `tvm::tir::SideEffect`, which checks for
    side effects of a single `PrimExpr`, `is_pure_function` checks for
    side effects of the function as a whole.
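
    As a hedged sketch of intended usage (the exact purity rules are
    those implemented by the analysis; the function body here is only
    illustrative):

        import tvm
        from tvm.script import tir as T

        @T.prim_func
        def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
            for i in range(16):
                B[i] = A[i] + T.float32(1)

        # Query purity of the function as a whole; assert_pure_function
        # raises an error rather than returning False.
        print(tvm.tir.analysis.is_pure_function(add_one))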
    
    * [Transform] Implement relax.transform.ComputePrimValue
    
    Prior to this commit, expressions of type `DataType::Int(64)` could
    be computed by `relax.transform.VMShapeLower`, but expressions of
    any other type could not.  This commit introduces
    `relax.transform.ComputePrimValue`, which produces `PrimFunc`
    subroutines to compute `PrimExpr` values of any dtype.
    
    This functionality will allow boolean values to be computed based on
    the symbolic values known at runtime.
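
    For illustration, a hedged sketch of the kind of expression this
    enables (the same `R.prim_value` syntax appears in the tests below;
    the `PrimFunc` subroutine itself is generated by the pass):

        from tvm.script import relax as R, tir as T

        @R.function
        def is_aligned(A: R.Tensor(["n"], "float32")) -> R.Prim("bool"):
            n = T.int64()
            # Boolean PrimExpr over a symbolic shape; ComputePrimValue
            # lowers this to a PrimFunc evaluated at runtime.
            return R.prim_value(n % 16 == 0)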
    
    * [Relax] Allow R.Prim('bool') in relax::If and assert_op
    
    Prior to this commit, the condition used for a `relax::If` node and
    the `"relax.assert_op"` operator was required to be a scalar tensor.
    This made it difficult to alter behavior based on a runtime shape
    parameter, for example, delegating to a vectorized implementation
    based on whether a tensor shape is divisible by the vector size.
    
    This commit adds support for expressions of type `R.Prim('bool')`
    as the condition for `relax::If` and `"relax.assert_op"`, enabling
    these use cases.
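
    A hedged sketch of the resulting capability (the branch bodies are
    stand-ins for vectorized and fallback implementations):

        from tvm.script import ir as I, relax as R, tir as T

        @I.ir_module
        class Module:
            @R.function
            def main(A: R.Tensor(["n"], "float32")) -> R.Tensor(["n"], "float32"):
                n = T.int64()
                cond = R.prim_value(n % 4 == 0)
                if cond:
                    out = R.multiply(A, R.const(2.0, "float32"))  # stand-in: vectorized
                else:
                    out = R.add(A, A)  # stand-in: fallback
                return out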
    
    * [Relax][Transform] Provide callback versions of LazyTransformParams
    
    Prior to this commit, the `LazyTransformParams` transform could be
    used to load model parameters on demand.  However, the functions
    used to load or set parameters needed to be registered within the
    global registry of `PackedFunc`s.  This commit provides the
    `LazyGetInput` and `LazySetOutput` transforms, which perform the
    lazy loading through an `R.Callable` callback argument, rather than
    through a globally-registered `PackedFunc`.
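
    On the Python side, the callbacks can then be ordinary closures.  A
    hedged sketch, in which the parameter store and the module handle
    `vm` are hypothetical:

        import numpy as np
        import tvm

        param_store = {"A": np.zeros((16, 16), "float32")}  # hypothetical weights

        def fget_param(index, name):
            # Load the requested weight on demand; both the index and
            # the name identify the parameter.
            return tvm.nd.array(param_store[str(name)])

        def fset_output(index, value):
            # Persist each output as soon as it is produced.
            np.save(f"output_{index}.npy", value.numpy())

        # e.g. vm["transform_params"](fget_param) after LazyGetInput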
    
    * Reverse the order of parameters in fget_param
    
    If `fget_param` accepts the parameter index first and the parameter
    name second, then an implementation with the signature and default
    values `def fget_param(index: int, name: Optional[str] = None)` can
    be used either as the callback of `LazyGetInput` or as the
    globally-registered `"get_item"` for the existing
    `LazyTransformParams`, which should make it easier to transition
    between the two.
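
    A minimal sketch of such a dual-use implementation (the weight
    store is hypothetical):

        from typing import Optional

        weights = {}  # hypothetical mapping from parameter index to NDArray

        def fget_param(index: int, name: Optional[str] = None):
            # Called as fget_param(index, name) by the LazyGetInput
            # callback, or as fget_param(index) via the "get_item"
            # registration used by the existing LazyTransformParams.
            return weights[int(index)]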
    
    * lint fix
    
    * Updates based on review comments
---
 python/tvm/relax/transform/__init__.py             |   2 +
 python/tvm/relax/transform/transform.py            |  80 +++++
 src/relax/transform/lazy_transform_params.cc       | 266 +++++++++++++++++
 .../relax/test_transform_lazy_transform_params.py  | 328 +++++++++++++++++++++
 4 files changed, 676 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 11e301c26c..5e76fff6bd 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -50,6 +50,8 @@ from .transform import (
     InlinePrivateFunctions,
     KillAfterLastUse,
     LambdaLift,
+    LazyGetInput,
+    LazySetOutput,
     LegalizeOps,
     LiftTransformParams,
     LowerAllocTensor,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index dbc35d48d3..fa18cc672b 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -303,6 +303,86 @@ def LambdaLift() -> tvm.ir.transform.Pass:
     return _ffi_api.LambdaLift()
 
 
+def LazyGetInput() -> tvm.ir.transform.Pass:
+    """A pass that requests inputs lazily.
+
+    In many cases, the size of the model weights exceeds the available
+    memory on a GPU.  When this happens, a function that accepts all
+    model weights as arguments cannot be called.  Instead, parameters
+    must be loaded as they are required by the function, and unloaded
+    once they are no longer needed.
+
+    This pass mutates a function such that all model weights
+    (arguments after the first `func.attrs["num_input"]` arguments)
+    are loaded on demand.  Rather than accepting the weights as
+    function arguments, the function accepts a callback argument,
+    which can load each parameter as needed.  The callback accepts two
+    arguments, first the index of the model weight, and second the
+    name of the parameter.  The callback should return the parameter
+    as specified.
+
+    .. code-block:: python
+
+        @R.function
+        def before(A: R.Tensor([16,32],"float32")):
+            ...
+
+        @R.function
+        def after(fget_param: R.Callable([R.Prim('int64'), R.Object], R.Object)):
+            A_untyped = fget_param(0, R.str('A'))
+            A = R.match_cast(A_untyped, R.Tensor([16,32], "float32"))
+            ...
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+
+    """
+    return _ffi_api.LazyGetInput()
+
+
+def LazySetOutput() -> tvm.ir.transform.Pass:
+    """A pass that sets function outputs when available
+
+    In many cases, the size of the model weights exceeds the available
+    memory on a GPU.  In these cases, a function that produces all
+    model weights as a single return value would not be able to be
+    called.  In these cases, parameters must be returned as they are
+    produced, unloaded from the GPU (or saved to disk), before
+    producing additional outputs.
+
+    This pass mutates a function such that each of its outputs is
+    passed to a callback as soon as it is available.  The function
+    accepts an additional callback argument, which is called once per
+    output.  The callback accepts two arguments, first the index
+    within the output tuple (or zero if the output is not a tuple),
+    and second the value itself.
+
+    .. code-block:: python
+
+        @R.function
+        def before(args):
+            ...
+            return (A, B)
+
+        @R.function
+        def after(args, fset_param: R.Callable([R.Prim('int64'), R.Object])):
+            ...
+            fset_param(0, A)
+            ...
+            fset_param(1, B)
+            ...
+            return ()
+
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+
+    """
+    return _ffi_api.LazySetOutput()
+
+
 def ConvertToDataflow(min_size: int = 2) -> tvm.ir.transform.Pass:
     """A pass that converts consecutive dataflow operations
     inside binding blocks into dataflow blocks.
diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc
new file mode 100644
index 0000000000..21608af7db
--- /dev/null
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*! \file src/relax/transform/lazy_transform_params.cc */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_map>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+std::optional<int64_t> GetNumInputParams(const FunctionNode* func) {
+  if (auto opt_int_imm = func->GetAttr<IntImm>(attr::kNumInput)) {
+    int64_t num_input_params = opt_int_imm.value()->value;
+    CHECK_GE(num_input_params, 0) << "ValueError: "
+                                  << "Annotation for attr::kNumInput (\"" << 
attr::kNumInput
+                                  << "\") must be non-negative, but was " << 
num_input_params;
+    CHECK_LE(static_cast<size_t>(num_input_params), func->params.size())
+        << "ValueError: "
+        << "Annotation for attr::kNumInput (\"" << attr::kNumInput << "\") 
specifies "
+        << num_input_params << " parameters to be provided at runtime, "
+        << "but the function only accepts " << func->params.size() << " 
parameters in total";
+    return num_input_params;
+  } else {
+    return std::nullopt;
+  }
+}
+
+class LazyInputMutator : public ExprMutator {
+ public:
+  Expr VisitExpr_(const FunctionNode* func) override {
+    if (plan_.has_value()) {
+      return ExprMutator::VisitExpr_(func);
+    }
+
+    int64_t num_input_params = GetNumInputParams(func).value_or(0);
+
+    std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> param_lookup;
+    for (size_t i = num_input_params; i < func->params.size(); i++) {
+      param_lookup.insert({func->params[i], i - num_input_params});
+    }
+
+    Var fget_param("fget_param",
+                   FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
+                                  ObjectStructInfo()));
+
+    Array<Var> new_params(func->params.begin(), func->params.begin() + num_input_params);
+    new_params.push_back(fget_param);
+
+    auto node = GetRef<Function>(func);
+    node.CopyOnWrite()->params = new_params;
+    node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));
+
+    plan_ = FunctionPlan{std::move(param_lookup), fget_param};
+    auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
+    plan_.reset();
+    return output;
+  }
+
+  Expr VisitExpr_(const VarNode* op) override {
+    if (plan_) {
+      Var var = GetRef<Var>(op);
+      if (auto it = plan_->param_lookup.find(var); it != plan_->param_lookup.end()) {
+        auto untyped =
+            builder_->Emit(relax::Call(plan_->fget_param,
+                                       {
+                                           PrimValue(IntImm(DataType::Int(64), it->second)),
+                                           StringImm(var->name_hint()),
+                                       }),
+                           var->name_hint() + "_untyped");
+        return builder_->EmitMatchCast(untyped, GetStructInfo(var), var->name_hint());
+      }
+    }
+
+    return ExprMutator::VisitExpr_(op);
+  }
+
+ private:
+  struct FunctionPlan {
+    std::unordered_map<Var, size_t, ObjectPtrHash, ObjectPtrEqual> param_lookup;
+    Expr fget_param;
+  };
+  std::optional<FunctionPlan> plan_;
+};
+
+class LazyOutputMutator : public ExprMutator {
+ public:
+  Expr VisitExpr_(const FunctionNode* func) override {
+    if (plan_.has_value()) {
+      return ExprMutator::VisitExpr_(func);
+    }
+
+    std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> output_lookup;
+    std::vector<std::tuple<size_t, Expr>> inline_outputs;
+    auto define_lookup = [&](size_t output_index, Expr output_value) {
+      if (auto var = output_value.as<Var>()) {
+        output_lookup[var.value()].push_back(output_index);
+      } else {
+        inline_outputs.push_back({output_index, output_value});
+      }
+    };
+
+    auto func_body = Downcast<SeqExpr>(func->body);
+    if (auto tuple_output = func_body->body.as<TupleNode>()) {
+      for (size_t i = 0; i < tuple_output->fields.size(); i++) {
+        define_lookup(i, tuple_output->fields[i]);
+      }
+    } else {
+      define_lookup(0, func_body->body);
+    }
+
+    Var fset_output("fset_output",
+                    FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()},
+                                   TupleStructInfo(Array<StructInfo>{})));
+    plan_ = FunctionPlan{std::move(output_lookup), fset_output};
+
+    std::optional<int64_t> num_input_params = GetNumInputParams(func);
+
+    auto new_params = func->params;
+    new_params.insert(new_params.begin() + num_input_params.value_or(func->params.size()),
+                      fset_output);
+
+    BindingBlock start_of_func = [&]() {
+      Array<Binding> propagated_params;
+      for (auto param : func->params) {
+        GenerateSetOutputCalls(param, [&](const auto& fset_output_call) {
+          Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
+          propagated_params.push_back(VarBinding(void_output, fset_output_call));
+        });
+      }
+      return BindingBlock(propagated_params);
+    }();
+    BindingBlock end_of_func = [&]() {
+      Array<Binding> propagated_params;
+      for (const auto& [output_index, expr] : inline_outputs) {
+        Call fset_output_call(fset_output,
+                              {PrimValue(IntImm(DataType::Int(64), output_index)), expr});
+        Var void_output("_void", TupleStructInfo(Array<StructInfo>{}));
+        propagated_params.push_back(VarBinding(void_output, fset_output_call));
+      }
+      return BindingBlock(propagated_params);
+    }();
+
+    Array<BindingBlock> new_blocks = func_body->blocks;
+    new_blocks.insert(new_blocks.begin(), start_of_func);
+    new_blocks.push_back(end_of_func);
+    Expr new_body = SeqExpr(new_blocks, Tuple(Array<Expr>{}));
+
+    auto node = GetRef<Function>(func);
+    {
+      auto write_ptr = node.CopyOnWrite();
+      write_ptr->params = new_params;
+      write_ptr->body = new_body;
+    }
+    if (num_input_params.has_value()) {
+      node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1));
+    }
+
+    auto output = Downcast<Function>(ExprMutator::VisitExpr_(node.get()));
+    plan_.reset();
+    return output;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    ExprMutator::VisitBinding(binding);
+    GenerateSetOutputCalls(binding->var, [this](const auto& fset_output_call) {
+      builder_->Emit(fset_output_call, "_void");
+    });
+  }
+
+ private:
+  template <typename Callback>
+  void GenerateSetOutputCalls(const Var& var, Callback callback) {
+    if (plan_.has_value()) {
+      if (auto it = plan_->output_lookup.find(var); it != plan_->output_lookup.end()) {
+        for (auto output_index : it->second) {
+          callback(
+              Call(plan_->fset_output, {PrimValue(IntImm(DataType::Int(64), output_index)), var}));
+        }
+      }
+    }
+  }
+
+  struct FunctionPlan {
+    std::unordered_map<Var, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> output_lookup;
+    Expr fset_output;
+  };
+  std::optional<FunctionPlan> plan_;
+};
+}  // namespace
+
+Function WithLazyInputs(Function func) {
+  LazyInputMutator mutator;
+
+  func = Downcast<Function>(mutator.VisitExpr(func));
+  func = Downcast<Function>(EliminateCommonSubexpr(func));
+  func = Downcast<Function>(RemoveAllUnused(func));
+  return func;
+}
+
+Function WithLazyOutputs(Function func) {
+  LazyOutputMutator mutator;
+
+  func = Downcast<Function>(mutator.VisitExpr(func));
+  return func;
+}
+
+namespace transform {
+
+Pass LazyGetInput() {
+  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+    if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+      return func;
+    }
+    return WithLazyInputs(func);
+  };
+  return CreateFunctionPass(/*pass_function=*/pass_func,
+                            /*opt_level=*/0,
+                            /*pass_name=*/"LazyGetInput",
+                            /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LazyGetInput").set_body_typed(LazyGetInput);
+
+Pass LazySetOutput() {
+  auto pass_func = [](Function func, IRModule, PassContext) -> Function {
+    if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+      return func;
+    }
+    return WithLazyOutputs(func);
+  };
+  return CreateFunctionPass(/*pass_function=*/pass_func,
+                            /*opt_level=*/0,
+                            /*pass_name=*/"LazySetOutput",
+                            /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.LazySetOutput").set_body_typed(LazySetOutput);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index b16de32ceb..833cbd460c 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -824,5 +824,333 @@ def test_params_without_tuple_with_symbolic_var():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+def test_get_item_callback():
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)):
+            R.func_attr({"num_input": 1})
+            A = fget_param(R.prim_value(0), R.str("A"))
+            A = R.match_cast(A, R.Tensor([16, 16], "float32"))
+            C = R.multiply(A, R.const(2, "float32"))
+
+            B = fget_param(R.prim_value(1), R.str("B"))
+            B = R.match_cast(B, R.Tensor([16, 16], "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    After = relax.transform.LazyGetInput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_get_item_callback_num_attrs():
+    @I.ir_module
+    class Before:
+        @R.function(pure=False)
+        def transform_params(
+            rank_arg: R.Prim(value="rank"),
+            world_size_arg: R.Prim(value="world_size"),
+            weight_A: R.Tensor([16, 64], "float32"),
+            weight_B: R.Tensor([1024, 2048], "float32"),
+        ):
+            R.func_attr({"num_input": 2})
+
+            rank = T.int64()
+            world_size = T.int64()
+
+            _ = R.assert_op(
+                R.prim_value(16 % world_size == 0),
+                [R.prim_value(16), R.prim_value(world_size)],
+                format=(
+                    "World size must evenly divide A.shape[0] ({}), "
+                    "but received world size of {}."
+                ),
+            )
+            weight_A = R.strided_slice(
+                weight_A,
+                axes=[0],
+                begin=[rank * 16 // world_size],
+                end=[(rank + 1) * 16 // world_size],
+            )
+
+            _ = R.assert_op(
+                R.prim_value(2048 % world_size == 0),
+                [R.prim_value(2048), R.prim_value(world_size)],
+                format=(
+                    "World size must evenly divide B.shape[1] ({}), "
+                    "but received world size of {}."
+                ),
+            )
+            weight_B = R.strided_slice(
+                weight_B,
+                axes=[1],
+                begin=[rank * 2048 // world_size],
+                end=[(rank + 1) * 2048 // world_size],
+            )
+
+            return (weight_A, weight_B)
+
+    @I.ir_module
+    class Expected:
+        @R.function(pure=False)
+        def transform_params(
+            rank_arg: R.Prim(value="rank"),
+            world_size_arg: R.Prim(value="world_size"),
+            fget_item: R.Callable([R.Prim("int64"), R.Object], R.Object),
+        ):
+            R.func_attr({"num_input": 3})
+
+            rank = T.int64()
+            world_size = T.int64()
+
+            _ = R.assert_op(
+                R.prim_value(16 % world_size == 0),
+                [R.prim_value(16), R.prim_value(world_size)],
+                format=(
+                    "World size must evenly divide A.shape[0] ({}), "
+                    "but received world size of {}."
+                ),
+            )
+            weight_A = fget_item(R.prim_value(0), R.str("weight_A"))
+            weight_A = R.match_cast(weight_A, R.Tensor([16, 64], "float32"))
+            weight_A = R.strided_slice(
+                weight_A,
+                axes=[0],
+                begin=[rank * 16 // world_size],
+                end=[(rank + 1) * 16 // world_size],
+            )
+
+            _ = R.assert_op(
+                R.prim_value(2048 % world_size == 0),
+                [R.prim_value(2048), R.prim_value(world_size)],
+                format=(
+                    "World size must evenly divide B.shape[1] ({}), "
+                    "but received world size of {}."
+                ),
+            )
+            weight_B = fget_item(R.prim_value(1), R.str("weight_B"))
+            weight_B = R.match_cast(weight_B, R.Tensor([1024, 2048], "float32"))
+            weight_B = R.strided_slice(
+                weight_B,
+                axes=[1],
+                begin=[rank * 2048 // world_size],
+                end=[(rank + 1) * 2048 // world_size],
+            )
+
+            return (weight_A, weight_B)
+
+    After = relax.transform.LazyGetInput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback():
+    """fset_output is called for each element of the output tuple
+
+    The call is placed immediately after the corresponding
+    `VarBinding`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, C)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+        ):
+            C = R.multiply(A, R.const(2, "float32"))
+            fset_output(R.prim_value(1), C)
+            D = R.add(C, B)
+            fset_output(R.prim_value(0), D)
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_of_param():
+    """fset_output may need to be called for parameters
+
+    A function parameter does not have a `VarBinding`.  If a parameter
+    is returned in the output tuple, the `fset_output` call is
+    generated at the beginning of the function.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+        ):
+            fset_output(R.prim_value(1), B)
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            fset_output(R.prim_value(0), D)
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_num_input():
+    """The parameter transformation may have other runtime parameters
+
+    The new `fset_output` parameter is placed after the other runtime
+    parameters, before any model weights.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            R.func_attr({"num_input": 1})
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+            B: R.Tensor([16, 16], "float32"),
+        ):
+            R.func_attr({"num_input": 2})
+            fset_output(R.prim_value(1), B)
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            fset_output(R.prim_value(0), D)
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_duplicate_output():
+    """fset_output may be called more than once for a variable
+
+    A variable may occur multiple times in the output tuple.  The
+    `fset_output` callback should be called once for each tuple
+    element, even if they reuse the same variable.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, D)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+        ):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            fset_output(R.prim_value(0), D)
+            fset_output(R.prim_value(1), D)
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_inline_const():
+    """fset_output may be called for inline objects
+
+    The return tuple may contain inline leaf nodes, such as
+    `relax.PrimValue` or `relax.Constant`.  A call to `fset_output`
+    must be generated, even though they do not have an associated
+    `relax.VarBinding`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (C, D, R.prim_value(42), R.const(17.5, "float16"))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+        ):
+            C = R.multiply(A, R.const(2, "float32"))
+            fset_output(R.prim_value(0), C)
+            D = R.add(C, B)
+            fset_output(R.prim_value(1), D)
+            fset_output(R.prim_value(2), R.prim_value(42))
+            fset_output(R.prim_value(3), R.const(17.5, "float16"))
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
+def test_set_output_callback_with_non_tuple_output():
+    """Non-tuple outputs produce a single call to fset_output"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return D
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            A: R.Tensor([16, 16], "float32"),
+            B: R.Tensor([16, 16], "float32"),
+            fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])),
+        ):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            fset_output(R.prim_value(0), D)
+            return R.tuple()
+
+    After = relax.transform.LazySetOutput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
