This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e355adf17 [Relax] Fix RemoveUnusedParameters symbolic var promotion (#19901)
3e355adf17 is described below

commit 3e355adf176c355e2ad8441c4088dbcee9bfcc96
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jun 29 17:18:12 2026 -0400

    [Relax] Fix RemoveUnusedParameters symbolic var promotion (#19901)
    
    After the tirx refactor, a Relax `PrimType` parameter carries only a
    dtype and no longer binds a value, so it does not *define* the symbolic
    variable it used to provide. When `RemoveUnusedParameters` drops an
    unused tensor parameter whose shape is the sole definition of a free
    symbolic variable, it re-adds that variable as a `PrimType` parameter
    (the successor of the old value-bearing `R.Prim(value=...)` annotation).
    Because that parameter no longer defines the variable, the stricter tirx
    well-formedness verifier reports it as undefined, and the pass emits an
    ill-formed module.
    
    Instead, promote each such free symbolic variable through a 1-D
    `ShapeType` parameter (`R.Shape([var])`), which does define the
    variable, and pass its value at the call site as a `ShapeExpr`. The
    previously xfail-marked `test_replace_symbolic_variables` is updated to
    the new shape-based form and re-enabled as a regression test.
---
 src/relax/transform/remove_unused_parameters.cc              |  9 +++++++--
 .../python/relax/test_transform_remove_unused_parameters.py  | 12 ++++++++----
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/src/relax/transform/remove_unused_parameters.cc 
b/src/relax/transform/remove_unused_parameters.cc
index df986a37d1..1020b1a716 100644
--- a/src/relax/transform/remove_unused_parameters.cc
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -100,7 +100,10 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) 
{
   }
 
   for (const auto& tir_var : free_tir_vars) {
-    Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.ty()));
+    // Promote the free symbolic var via a 1-D shape param so the param 
actually
+    // *defines* the var. A PrimType param only carries a dtype and defines no
+    // TIR var, which leaves the var undefined under the strict tirx verifier.
+    Var relax_var("param_" + tir_var->name_hint, ShapeType({tir_var}));
     params.push_back(relax_var);
   }
 
@@ -129,7 +132,9 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
       auto tir_binding = InferSymbolicVarMap(old_binding, analyzer);
 
       for (const auto& tir_var : free_tir_vars) {
-        new_args.push_back(PrimExpr(tir_binding.at(tir_var)));
+        // Pass the symbolic var value as a 1-D shape, matching the ShapeType
+        // param that now defines the var in the callee.
+        new_args.push_back(ShapeExpr({tir_binding.at(tir_var)}));
       }
     }
 
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py 
b/tests/python/relax/test_transform_remove_unused_parameters.py
index 4c05cbdb29..3eaf8270bd 100644
--- a/tests/python/relax/test_transform_remove_unused_parameters.py
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import pytest
 
 import tvm
 import tvm.testing
@@ -56,7 +55,6 @@ def test_remove_unused_relax_parameter():
     tvm.ir.assert_structural_equal(After, Expected)
 
 
-@pytest.mark.xfail(reason="value-bearing R.Prim annotations were removed")
 def test_replace_symbolic_variables():
     """If a parameter is only required for its symbolic variables, provide 
them directly
 
@@ -64,6 +62,12 @@ def test_replace_symbolic_variables():
     its shape defines the symbolic variables `m` and `n`.  When
     removing the `R.Tensor` argument, we may need to provide
     additional parameters to define the symbolic variables.
+
+    Value-bearing `R.Prim(value=...)` annotations were removed in the tirx
+    refactor (a `PrimType` carries only a dtype and defines no TIR var, which
+    leaves the var undefined under the strict tirx well-formedness verifier).
+    The replacement is to promote each free symbolic variable through a 1-D
+    `R.Shape` parameter, which actually *defines* the variable.
     """
 
     @I.ir_module
@@ -84,11 +88,11 @@ def test_replace_symbolic_variables():
         def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], 
"float32"):
             m = T.int64()
             n = T.int64()
-            out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), 
R.prim_value(m))
+            out: R.Tensor([m, n], "float32") = Expected.func(R.shape([n]), 
R.shape([m]))
             return out
 
         @R.function(private=True)
-        def func(param_n: R.Prim(value="n"), param_m: R.Prim(value="m")) -> 
R.Tensor(
+        def func(param_n: R.Shape(["n"]), param_m: R.Shape(["m"])) -> R.Tensor(
             ["m", "n"], "float32"
         ):
             m = T.int64()

Reply via email to