This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a692453612 [Relax] Ignore non-relax functions in relax.transform.RunCodegen (#16586)
a692453612 is described below

commit a6924536127d3bbc222f113a66fafae001044db3
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sat Feb 17 08:20:29 2024 -0600

    [Relax] Ignore non-relax functions in relax.transform.RunCodegen (#16586)
    
    * [Relax] Ignore non-relax functions in relax.transform.RunCodegen
    
    The `relax.transform.RunCodegen` pass replaces calls to relax functions
    that have the `"Codegen"` attribute with calls into a compiled module.
    Prior to this commit, while calls to relax functions without the
    `"Codegen"` attribute were ignored, calls to non-relax functions would
    raise an error.
    
    This commit updates `relax.transform.RunCodegen` to also ignore calls
    to non-relax functions.
    
    * Remove debug changes
---
 src/relax/transform/run_codegen.cc                |  4 ++--
 tests/python/relax/test_transform_codegen_pass.py | 29 +++++++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc
index c385ae46ef..fe0e73d99e 100644
--- a/src/relax/transform/run_codegen.cc
+++ b/src/relax/transform/run_codegen.cc
@@ -116,9 +116,9 @@ class CodeGenRunner : ExprMutator {
       auto ret_sinfo = GetStructInfo(call);
       if (auto it = extern_funcs_.find(gvar_node); it != extern_funcs_.end()) {
         return create_call_dps_packed(it->second, ret_sinfo);
-      } else {
+      } else if (auto opt_func = 
builder_->GetContextIRModule()->Lookup(gvar).as<Function>()) {
         // TODO(@sunggg): Is there any better way to get this func?
-        Function func = 
Downcast<Function>(builder_->GetContextIRModule()->Lookup(gvar));
+        Function func = opt_func.value();
         Expr new_func = VisitExpr(func);
 
         if (new_func->IsInstance<ExternFuncNode>()) {
diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py
index cc8f390b96..560bd3bc0b 100644
--- a/tests/python/relax/test_transform_codegen_pass.py
+++ b/tests/python/relax/test_transform_codegen_pass.py
@@ -352,6 +352,35 @@ def test_dynamic_shape():
     tvm.ir.assert_structural_equal(after["main"], Expected["main"])
 
 
+def test_no_op_for_call_to_tir():
+    """Calls to PrimFunc are ignored
+
+    RunCodegen should only update calls to Relax functions annotated
+    with the `"Codegen"` attribute.  Calls to any other function type
+    should be ignored.
+
+    This is a regression test.  Previous implementations performed an
+    unconditional cast from `tvm::BaseFunc` to `tvm::relax::Function`,
+    which produced an error.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            R.func_attr({"relax.force_pure": True})
+            _ = Before.shape_func(x)
+            return x
+
+        @T.prim_func(private=True)
+        def shape_func(H: T.Buffer(T.int64(4), "int64")):
+            H[T.int64(0)] = H[T.int64(0)] + T.int64(1)
+
+    Expected = Before
+    After = relax.transform.RunCodegen()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 # TODO(@sunggg):  test with more complex patterns (e.g., multiple annots, 
mixed codegens, different ops, const binding)
 
 if __name__ == "__main__":

Reply via email to