This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 059f629ec24 [BYOC] Skip processed functions in FuseOpsByPattern and RunCodegen (#16567)
059f629ec24 is described below
commit 059f629ec24a6b36e93794cd41a97b9bed3bdcf5
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Feb 14 05:39:45 2024 -0800
[BYOC] Skip processed functions in FuseOpsByPattern and RunCodegen (#16567)
---
src/relax/transform/fuse_ops.cc | 10 +++++++++-
tests/python/relax/test_transform_fuse_ops_by_pattern.py | 9 +++++++++
2 files changed, 18 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 32780f6dd25..0dbee366706 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1199,7 +1199,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
auto all_functions = mod->functions;
for (const auto& entry : all_functions) {
if (const auto* func = entry.second.as<FunctionNode>()) {
- if (func->GetAttr<String>(attr::kComposite).defined()) {
+ if (func->GetAttr<String>(attr::kComposite).defined() ||
+ func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
auto new_body = VisitExpr(func->body);
@@ -1270,6 +1271,13 @@ IRModule FuseOpsByPattern(const tvm::Array<transform::FusionPattern>& patterns,
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
continue;
}
+ const FunctionNode* function = entry.second.as<FunctionNode>();
+ if (function->GetAttr<Integer>(attr::kPrimitive).defined() ||
+ function->GetAttr<String>(attr::kComposite).defined() ||
+ function->GetAttr<String>(attr::kCodegen).defined()) {
+ continue;
+ }
+
auto map = PatternBasedPartitioner::Run(pattern->name, pattern->pattern,
pattern->annotation_patterns,
pattern->check.value_or(nullptr), entry.second,
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index bd434864a08..de356fd5480 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1046,5 +1046,14 @@ def test_intermediate_var_to_var_binding():
assert "fused_relax_permute_dims_relax_matmul_cublas" in func_names # add
is not fused
+def test_multple_runs():
+    # Regression test for skipping already-processed functions: the same module is
+    # both input and expected output (presumably `check` compares them), so running
+    # FuseOpsByPattern again must be a no-op. NOTE(review): "multple" is a typo.
+    annotated_mod = Conv2dReLU_composite_annotated
+    dnnl_patterns = [("dnnl.conv2d_relu", conv2d_relu_pat)]
+    check(annotated_mod, dnnl_patterns, annotated_mod, annotate_codegen=True)
+
+
if __name__ == "__main__":
    pytest.main([__file__])