This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 2d7c673fcc [Unity] Add bind_constants option to FuseOpsByPattern
(#14151)
2d7c673fcc is described below
commit 2d7c673fccb89ed11355726fd06a29c69474b4fd
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Mar 1 12:55:00 2023 -0800
[Unity] Add bind_constants option to FuseOpsByPattern (#14151)
* [Unity] Add lift_constants option to FuseOpsByPattern
* Rename lift_constants -> bind_constants
---
include/tvm/relax/transform.h | 4 +-
python/tvm/relax/backend/contrib/cutlass.py | 4 +-
python/tvm/relax/transform/transform.py | 7 ++-
src/relax/transform/fuse_ops.cc | 12 +++--
tests/python/relax/test_codegen_cutlass.py | 2 +-
.../relax/test_transform_fuse_ops_by_pattern.py | 61 ++++++++++++++++++++--
6 files changed, 77 insertions(+), 13 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 907f0bf8cd..715c8e56ff 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -225,6 +225,7 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
* of priority in which they are matched. Higher-priority patterns should come
earlier in the list.
* \param checks The callback functions with type (Map<DFPattern, Expr>, Expr)
-> bool. It takes a
* match result and returns a boolean value to indicate whether the match
result is accepted.
+ * \param bind_constants Whether or not to keep bound constants of the grouped
function.
* \param annotate_codegen If true, wrap each created composite function with
another function,
* whose body consists only of a call to the composite function, and annotate
the outer function
* with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set
as the prefix of the
@@ -235,7 +236,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
const tvm::Array<DFPattern>& patterns,
- const tvm::Array<PackedFunc>& checks, bool
annotate_codegen = false);
+ const tvm::Array<PackedFunc>& checks, bool
bind_constants = true,
+ bool annotate_codegen = false);
/*!
* \brief Group one or multiple composite functions created by
FuseOpsByPattern into a new
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 51684abb06..e98194ca21 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -117,4 +117,6 @@ def partition_for_cutlass(mod):
"""
cutlass_patterns = get_patterns_with_prefix("cutlass")
- return transform.FuseOpsByPattern(cutlass_patterns,
annotate_codegen=True)(mod)
+ return transform.FuseOpsByPattern(cutlass_patterns, bind_constants=True,
annotate_codegen=True)(
+ mod
+ )
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 97daae4941..a33ad63093 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -283,7 +283,7 @@ def FuseTIR() -> tvm.ir.transform.Pass:
def FuseOpsByPattern(
- patterns: List[Tuple], annotate_codegen: bool = False
+ patterns: List[Tuple], bind_constants: bool = True, annotate_codegen: bool
= False
) -> tvm.ir.transform.Pass:
"""Apply pattern matching to each function in the given module, and group
matched expressions
into a new function.
@@ -302,6 +302,9 @@ def FuseOpsByPattern(
The string is the name of the corresponding pattern. It becomes the
value of the kComposite
attribute of a fused function after a successful matching.
+ bind_constants : bool
+ Whether or not to keep bound constants in the grouped function.
+
annotate_codegen : bool
If True, wrap each created composite function with another function,
whose body consists
only of a call to the composite function, and annotate the outer
function with "Codegen"
@@ -332,7 +335,7 @@ def FuseOpsByPattern(
else:
raise ValueError("Invalid pattern: {}".format(tup))
return _ffi_api.FuseOpsByPattern(
- pattern_names, df_patterns, checks, annotate_codegen
+ pattern_names, df_patterns, checks, bind_constants, annotate_codegen
) # type: ignore
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 60b2c77e49..72427a8a0e 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -574,6 +574,8 @@ class OperatorFusor : public ExprMutator {
* \param mod The IRModule to be transformed
* \param graph The indexed-forward graph of the input IRModule
* \param groups The grouped result of the group partition on the input
indexed-forward graph.
+ * \param lift_constant Whether or not to lift bound constants to parameters
of the grouped
+ * function.
*/
OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, const
std::vector<Group*>& groups,
bool lift_constant = true)
@@ -1052,7 +1054,7 @@ class CompositeFunctionAnnotator : public ExprMutator {
IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
const tvm::Array<DFPattern>& patterns,
const tvm::Array<runtime::PackedFunc>& checks,
IRModule mod,
- bool annotate_codegen) {
+ bool bind_constants, bool annotate_codegen) {
support::Arena arena;
for (size_t i = 0; i < pattern_names.size(); ++i) {
OperatorFusor::GroupMap group_map;
@@ -1064,7 +1066,7 @@ IRModule FuseOpsByPattern(const tvm::Array<String>&
pattern_names,
entry.second, &arena);
group_map.insert(map.begin(), map.end());
}
- mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
+ mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/
!bind_constants);
}
if (annotate_codegen) {
return CompositeFunctionAnnotator(mod).Run();
@@ -1091,10 +1093,12 @@
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
const tvm::Array<DFPattern>& patterns,
- const tvm::Array<runtime::PackedFunc>& checks, bool
annotate_codegen) {
+ const tvm::Array<runtime::PackedFunc>& checks, bool
bind_constants,
+ bool annotate_codegen) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
- return relax::FuseOpsByPattern(pattern_names, patterns, checks, m,
annotate_codegen);
+ return relax::FuseOpsByPattern(pattern_names, patterns, checks, m,
bind_constants,
+ annotate_codegen);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index bce8c5a84f..6eb476496c 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -100,7 +100,7 @@ def get_result_with_relax_cutlass_offload(mod, *args):
seq = tvm.transform.Sequential(
[
- relax.transform.FuseOpsByPattern(patterns, annotate_codegen=True),
+ relax.transform.FuseOpsByPattern(patterns, bind_constants=False,
annotate_codegen=True),
relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}}),
]
)
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index a9e76feb6c..b4b5591c01 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -389,8 +389,8 @@ conv2d_pat =
make_fused_bias_activation_pattern("relax.nn.conv2d", activation=No
conv2d_relu_pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
activation="relax.nn.relu")
-def check(mod, patterns, expected, annoatate_codegen=False):
- partitioned = relax.transform.FuseOpsByPattern(patterns,
annoatate_codegen)(mod)
+def check(mod, patterns, expected, bind_constants=True,
annoatate_codegen=False):
+ partitioned = relax.transform.FuseOpsByPattern(patterns, bind_constants,
annoatate_codegen)(mod)
tvm.ir.assert_structural_equal(partitioned, expected)
@@ -424,7 +424,9 @@ def test_cyclic_dependency():
add_pat = is_op("relax.add")(relu_pat, wildcard())
with pytest.raises(tvm.error.TVMError) as err:
- relax.transform.FuseOpsByPattern([("compiler_A.conv2d_relu_add",
add_pat)])(Branch)
+ relax.transform.FuseOpsByPattern(
+ [("compiler_A.conv2d_relu_add", add_pat)], bind_constants=True
+ )(Branch)
assert "A cyclic dependency detected" in str(err.value)
@@ -434,7 +436,9 @@ def test_bind_params():
mod = tvm.transform.Sequential(
[
relax.transform.BindParams("main", {"weight1": weight_np}),
- relax.transform.FuseOpsByPattern([("dnnl.conv2d_relu",
conv2d_relu_pat)]),
+ relax.transform.FuseOpsByPattern(
+ [("dnnl.conv2d_relu", conv2d_relu_pat)], bind_constants=True
+ ),
]
)(Conv2dReLU)
@@ -589,5 +593,54 @@ def test_check_pattern():
check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2) # expect no
partitioning
+def test_bind_constants():
+ weight = np.random.randn(64, 64, 3, 3).astype("float32")
+
+ @I.ir_module
+ class Conv2dWithConstantWeight:
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), "float32"),
+ weight1: R.Tensor((64, 64, 3, 3), "float32"),
+ ):
+ with R.dataflow():
+ conv1 = R.nn.conv2d(data, R.const(weight), padding=(1, 1))
+ R.output(conv1)
+ return conv1
+
+ @I.ir_module
+ class Conv2dWithConstantWeight_partitioned:
+ @R.function
+ def fused_relax_nn_conv2d(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ param_0: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ R.func_attr({"Composite": "cutlass.conv2d", "Primitive": 1})
+ with R.dataflow():
+ gv = R.nn.conv2d(data, param_0, padding=(1, 1))
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ data: R.Tensor((1, 64, 56, 56), dtype="float32"),
+ weight1: R.Tensor((64, 64, 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 64, 56, 56), dtype="float32"):
+ with R.dataflow():
+ gv: R.Tensor((1, 64, 56, 56), dtype="float32") =
fused_relax_nn_conv2d(
+ data, R.const(weight)
+ )
+ R.output(gv)
+ return gv
+
+ pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
+ check(
+ Conv2dWithConstantWeight,
+ [("cutlass.conv2d", pat)],
+ Conv2dWithConstantWeight_partitioned,
+ bind_constants=False,
+ )
+
+
if __name__ == "__main__":
pytest.main([__file__])