This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f74762743 [Relax] Provide well-formed output in `transform.LazyGetInput` (#16841)
6f74762743 is described below

commit 6f747627431e1d2863c02a58f0e985a0f7c49298
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Apr 3 21:37:15 2024 -0500

    [Relax] Provide well-formed output in `transform.LazyGetInput` (#16841)
    
    Prior to this commit, symbolic variables inferred from the parameters
    were retained in the output function's `ret_struct_info`.  This is
    ill-formed, as the parameters from which these symbolic variables are
    inferred are no longer part of the function signature.
    
    This commit updates `LazyGetInput` to use `EraseToWellDefined` to
    remove any symbolic variables from `ret_struct_info` that cannot be
    inferred from the remaining arguments.
---
 src/relax/transform/lazy_transform_params.cc       | 14 +++++++++
 .../relax/test_transform_lazy_transform_params.py  | 34 ++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc
index 21608af7db..37827fbe0e 100644
--- a/src/relax/transform/lazy_transform_params.cc
+++ b/src/relax/transform/lazy_transform_params.cc
@@ -71,8 +71,22 @@ class LazyInputMutator : public ExprMutator {
     Array<Var> new_params(func->params.begin(), func->params.begin() + num_input_params);
     new_params.push_back(fget_param);
 
+    auto array_externally_visible_vars =
+        DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo)));
+    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> externally_visible_vars(
+        array_externally_visible_vars.begin(), array_externally_visible_vars.end());
+    StructInfo new_ret_struct_info =
+        EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional<PrimExpr> {
+          if (externally_visible_vars.count(var)) {
+            return var;
+          } else {
+            return NullOpt;
+          }
+        });
+
     auto node = GetRef<Function>(func);
     node.CopyOnWrite()->params = new_params;
+    node.CopyOnWrite()->ret_struct_info = new_ret_struct_info;
     node = WithAttr(node, attr::kNumInput, Integer(num_input_params + 1));
 
     plan_ = FunctionPlan{std::move(param_lookup), fget_param};
diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py
index 833cbd460c..040aea2890 100644
--- a/tests/python/relax/test_transform_lazy_transform_params.py
+++ b/tests/python/relax/test_transform_lazy_transform_params.py
@@ -951,6 +951,40 @@ def test_get_item_callback_num_attrs():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+def test_get_item_callback_dynamic_shape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def transform_params(
+            A: R.Tensor(["m", "n"], "float32"), B: R.Tensor(["m", "n"], "float32")
+        ) -> R.Tuple(R.Tensor(["m", "n"], "float32"), R.Tensor(["m", "n"], "float32")):
+            C = R.multiply(A, R.const(2, "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def transform_params(
+            fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object)
+        ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")):
+            R.func_attr({"num_input": 1})
+            m = T.int64()
+            n = T.int64()
+
+            A = fget_param(R.prim_value(0), R.str("A"))
+            A = R.match_cast(A, R.Tensor([m, n], "float32"))
+            C = R.multiply(A, R.const(2, "float32"))
+
+            B = fget_param(R.prim_value(1), R.str("B"))
+            B = R.match_cast(B, R.Tensor([m, n], "float32"))
+            D = R.add(C, B)
+            return (D, B)
+
+    After = relax.transform.LazyGetInput()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 def test_set_output_callback():
     """fset_output is called for each element of the output tuple
 

Reply via email to