This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3e355adf17 [Relax] Fix RemoveUnusedParameters symbolic var promotion (#19901)
3e355adf17 is described below
commit 3e355adf176c355e2ad8441c4088dbcee9bfcc96
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Jun 29 17:18:12 2026 -0400
[Relax] Fix RemoveUnusedParameters symbolic var promotion (#19901)
After the tirx refactor, a relax `PrimType` parameter carries only a
dtype and no longer binds a value, so it does not *define* the symbolic
variable it used to provide. When `RemoveUnusedParameters` drops an
unused tensor parameter whose shape is the sole definition of a free
symbolic variable, it re-adds that variable through a `PrimType`
parameter (formerly the value-bearing `R.Prim(value=...)` form). Because
such a parameter now carries only a dtype, it no longer defines the
variable, so the stricter tirx well-formedness verifier reports the
variable as undefined and the pass emits an ill-formed module.
Promote each such free symbolic variable through a 1-D `ShapeType`
parameter (`R.Shape([var])`) instead, which actually defines the
variable, and pass its value at the call site as a `ShapeExpr`. The
previously xfail-marked `test_replace_symbolic_variables` is updated to the
new shape-based form and re-enabled as a regression test.
---
src/relax/transform/remove_unused_parameters.cc | 9 +++++++--
.../python/relax/test_transform_remove_unused_parameters.py | 12 ++++++++----
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/src/relax/transform/remove_unused_parameters.cc
b/src/relax/transform/remove_unused_parameters.cc
index df986a37d1..1020b1a716 100644
--- a/src/relax/transform/remove_unused_parameters.cc
+++ b/src/relax/transform/remove_unused_parameters.cc
@@ -100,7 +100,10 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func)
{
}
for (const auto& tir_var : free_tir_vars) {
- Var relax_var("param_" + tir_var->name_hint, PrimType(tir_var.ty()));
+ // Promote the free symbolic var via a 1-D shape param so the param
actually
+ // *defines* the var. A PrimType param only carries a dtype and defines no
+ // TIR var, which leaves the var undefined under the strict tirx verifier.
+ Var relax_var("param_" + tir_var->name_hint, ShapeType({tir_var}));
params.push_back(relax_var);
}
@@ -129,7 +132,9 @@ std::optional<CalleeAnalysis> AnalyzeCallee(Function func) {
auto tir_binding = InferSymbolicVarMap(old_binding, analyzer);
for (const auto& tir_var : free_tir_vars) {
- new_args.push_back(PrimExpr(tir_binding.at(tir_var)));
+ // Pass the symbolic var value as a 1-D shape, matching the ShapeType
+ // param that now defines the var in the callee.
+ new_args.push_back(ShapeExpr({tir_binding.at(tir_var)}));
}
}
diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py
b/tests/python/relax/test_transform_remove_unused_parameters.py
index 4c05cbdb29..3eaf8270bd 100644
--- a/tests/python/relax/test_transform_remove_unused_parameters.py
+++ b/tests/python/relax/test_transform_remove_unused_parameters.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import pytest
import tvm
import tvm.testing
@@ -56,7 +55,6 @@ def test_remove_unused_relax_parameter():
tvm.ir.assert_structural_equal(After, Expected)
[email protected](reason="value-bearing R.Prim annotations were removed")
def test_replace_symbolic_variables():
"""If a parameter is only required for its symbolic variables, provide
them directly
@@ -64,6 +62,12 @@ def test_replace_symbolic_variables():
its shape defines the symbolic variables `m` and `n`. When
removing the `R.Tensor` argument, we may need to provide
additional parameters to define the symbolic variables.
+
+ Value-bearing `R.Prim(value=...)` annotations were removed in the tirx
+ refactor (a `PrimType` carries only a dtype and defines no TIR var, which
+ leaves the var undefined under the strict tirx well-formedness verifier).
+ The replacement is to promote each free symbolic variable through a 1-D
+ `R.Shape` parameter, which actually *defines* the variable.
"""
@I.ir_module
@@ -84,11 +88,11 @@ def test_replace_symbolic_variables():
def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"],
"float32"):
m = T.int64()
n = T.int64()
- out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n),
R.prim_value(m))
+ out: R.Tensor([m, n], "float32") = Expected.func(R.shape([n]),
R.shape([m]))
return out
@R.function(private=True)
- def func(param_n: R.Prim(value="n"), param_m: R.Prim(value="m")) ->
R.Tensor(
+ def func(param_n: R.Shape(["n"]), param_m: R.Shape(["m"])) -> R.Tensor(
["m", "n"], "float32"
):
m = T.int64()