This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 059f629ec24 [BYOC] Skip processed functions in FuseOpsByPattern and RunCodegen (#16567)
059f629ec24 is described below

commit 059f629ec24a6b36e93794cd41a97b9bed3bdcf5
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Feb 14 05:39:45 2024 -0800

    [BYOC] Skip processed functions in FuseOpsByPattern and RunCodegen (#16567)
---
 src/relax/transform/fuse_ops.cc                          | 10 +++++++++-
 tests/python/relax/test_transform_fuse_ops_by_pattern.py |  9 +++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 32780f6dd25..0dbee366706 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1199,7 +1199,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
     auto all_functions = mod->functions;
     for (const auto& entry : all_functions) {
       if (const auto* func = entry.second.as<FunctionNode>()) {
-        if (func->GetAttr<String>(attr::kComposite).defined()) {
+        if (func->GetAttr<String>(attr::kComposite).defined() ||
+            func->GetAttr<String>(attr::kCodegen).defined()) {
           continue;
         }
         auto new_body = VisitExpr(func->body);
@@ -1270,6 +1271,13 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
       if (entry.second->IsInstance<tir::PrimFuncNode>()) {
         continue;
       }
+      const FunctionNode* function = entry.second.as<FunctionNode>();
+      if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
+          function->GetAttr<String>(attr::kComposite).defined() ||
+          function->GetAttr<String>(attr::kCodegen).defined()) {
+        continue;
+      }
+
       auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
                                               pattern->annotation_patterns,
                                               
pattern->check.value_or(nullptr), entry.second,
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index bd434864a08..de356fd5480 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1046,5 +1046,14 @@ def test_intermediate_var_to_var_binding():
     assert "fused_relax_permute_dims_relax_matmul_cublas" in func_names  # add is not fused
 
 
+def test_multple_runs():
+    check(
+        Conv2dReLU_composite_annotated,
+        [("dnnl.conv2d_relu", conv2d_relu_pat)],
+        Conv2dReLU_composite_annotated,
+        annotate_codegen=True,
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to