This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a692453612 [Relax] Ignore non-relax functions in relax.transform.RunCodegen (#16586)
a692453612 is described below
commit a6924536127d3bbc222f113a66fafae001044db3
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Feb 17 08:20:29 2024 -0600
[Relax] Ignore non-relax functions in relax.transform.RunCodegen (#16586)
* [Relax] Ignore non-relax functions in relax.transform.RunCodegen
The `relax.transform.RunCodegen` pass replaces calls to relax functions
that have the `"Codegen"` attribute with calls into a compiled module.
Prior to this commit, while calls to relax functions without the
`"Codegen"` attribute were ignored, calls to non-relax functions would
raise an error.
This commit updates `relax.transform.RunCodegen` to also ignore calls
to non-relax functions.
* Remove debug changes
---
src/relax/transform/run_codegen.cc | 4 ++--
tests/python/relax/test_transform_codegen_pass.py | 29 +++++++++++++++++++++++
2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index c385ae46ef..fe0e73d99e 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -116,9 +116,9 @@ class CodeGenRunner : ExprMutator {
auto ret_sinfo = GetStructInfo(call);
if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
return create_call_dps_packed(it->second, ret_sinfo);
- } else {
+ } else if (auto opt_func = builder_->GetContextIRModule()->Lookup(gvar).as<Function>()) {
// TODO(@sunggg): Is there any better way to get this func?
- Function func = Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
+ Function func = opt_func.value();
Expr new_func = VisitExpr(func);
if (new_func->IsInstance<ExternFuncNode>()) {
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index cc8f390b96..560bd3bc0b 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -352,6 +352,35 @@ def test_dynamic_shape():
tvm.ir.assert_structural_equal(after["main"], Expected["main"])
+def test_no_op_for_call_to_tir():
+ """Calls to PrimFunc are ignored
+
+ RunCodegen should only update calls to Relax functions annotated
+ with the `"Codegen"` attribute. Calls to any other function type
+ should be ignored.
+
+ This is a regression test. Previous implementations performed an
+ unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`,
+ which produced an error.
+ """
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(x: R.Tensor):
+ R.func_attr({"relax.force_pure": True})
+ _ = Before.shape_func(x)
+ return x
+
+ @T.prim_func(private=True)
+ def shape_func(H: T.Buffer(T.int64(4), "int64")):
+ H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
+
+ Expected = Before
+ After = relax.transform.RunCodegen()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
# TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding)
if __name__ == "__main__":