This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6f74762743 [Relax] Provide well-formed output in `transform.LazyGetInput` (#16841)
6f74762743 is described below
commit 6f747627431e1d2863c02a58f0e985a0f7c49298
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 21:37:15 2024 -0500
[Relax] Provide well-formed output in `transform.LazyGetInput` (#16841)
Prior to this commit, symbolic variables inferred from the parameters
were retained in the output function's `ret_struct_info`. This is
ill-formed, as the parameters from which these symbolic variables are
inferred are no longer part of the function signature.
This commit updates `LazyGetInput` to use `EraseToWellDefined` to
remove any symbolic variables from `ret_struct_info` that cannot be
inferred from the remaining arguments.
---
src/relax/transform/lazy_transform_params.cc | 14 +++++++++
.../relax/test_transform_lazy_transform_params.py | 34 ++++++++++++++++++++++
2 files changed, 48 insertions(+)
diff --git a/src/relax/transform/lazy_transform_params.cc
b/src/relax/transform/lazy_transform_params.cc
index 21608af7db..37827fbe0e 100644
--- a/src/relax/transform/lazy_transform_params.cc
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -71,8 +71,22 @@ class LazyInputMutator : public ExprMutator {
Array<Var> new_params(func->params.begin(), func->params.begin() +
num_input_params);
new_params.push_back(fget_param);
+ auto array_externally_visible_vars =
+
DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo)));
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
externally_visible_vars(
+ array_externally_visible_vars.begin(),
array_externally_visible_vars.end());
+ StructInfo new_ret_struct_info =
+ EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) ->
Optional<PrimExpr> {
+ if (externally_visible_vars.count(var)) {
+ return var;
+ } else {
+ return NullOpt;
+ }
+ });
+
auto node = GetRef<Function>(func);
node.CopyOnWrite()->params = new_params;
+ node.CopyOnWrite()->ret_struct_info = new_ret_struct_info;
node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));
plan_ = FunctionPlan{std::move(param_lookup), fget_param};
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py
b/tests/python/relax/test_transform_lazy_transform_params.py
index 833cbd460c..040aea2890 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -951,6 +951,40 @@ def test_get_item_callback_num_attrs():
tvm.ir.assert_structural_equal(After, Expected)
+def test_get_item_callback_dynamic_shape():
+ @I.ir_module
+ class Before:
+ @R.function
+ def transform_params(
+ A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"],
"float32")
+ ) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"],
"float32")):
+ C = R.multiply(A, R.const(2, "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def transform_params(
+ fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
+ ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2,
dtype="float32")):
+ R.func_attr({"num_input": 1})
+ m = T.int64()
+ n = T.int64()
+
+ A = fget_param(R.prim_value(0), R.str("A"))
+ A = R.match_cast(A, R.Tensor([m, n], "float32"))
+ C = R.multiply(A, R.const(2, "float32"))
+
+ B = fget_param(R.prim_value(1), R.str("B"))
+ B = R.match_cast(B, R.Tensor([m, n], "float32"))
+ D = R.add(C, B)
+ return (D, B)
+
+ After = relax.transform.LazyGetInput()(Before)
+ tvm.ir.assert_structural_equal(After, Expected)
+
+
def test_set_output_callback():
"""fset_output is called for each element of the output tuple