This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 3b33caf757 [Unity] Cover all Relax functions in implicit attention rewrite (#14818)
3b33caf757 is described below
commit 3b33caf757163c2e360e6f6f978e907ba183bcbf
Author: Lite Ye <[email protected]>
AuthorDate: Thu May 11 01:27:23 2023 -0400
[Unity] Cover all Relax functions in implicit attention rewrite (#14818)
* Rewrite all functions in attention op rewriting
* Fix lint
---
python/tvm/relax/backend/contrib/cutlass.py | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index d5940ac5e4..19fc2a39ea 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -17,24 +17,24 @@
"""Pattern table for CUTLASS backend"""
import operator
-from typing import Mapping, Sequence
from functools import reduce
+from typing import Mapping, Sequence
import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator, expr_functor, Function
-from tvm.relax.transform import PatternCheckContext
+from tvm.relax import Call, DataflowVar, Function, PyExprMutator, Var, expr_functor, transform
from tvm.relax.dpl import rewrite_call
+from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
from ..patterns import (
make_attention_pattern,
+ make_attention_rewrite_pattern,
make_fused_bias_activation_pattern,
+ make_layer_norm_pattern,
make_matmul_pattern,
make_residual_block_pattern,
make_stacked_attention_pattern,
- make_layer_norm_pattern,
- make_attention_rewrite_pattern,
)
@@ -435,8 +435,10 @@ def partition_for_cutlass(mod, annotate_codegen=True):
The resulting IRModule, containing partitioned subgraphs to be
compiled by the CUTLASS backend.
"""
- for pattern, rewriter in _REWRITE_PATTERNS:
- mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
+ for func_name, func in mod.functions.items():
+ if isinstance(func, Function):
+ for pattern, rewriter in _REWRITE_PATTERNS:
+ mod[func_name] = rewrite_call(pattern, rewriter, func)
patterns = get_patterns_with_prefix("cutlass")
return tvm.transform.Sequential(
[