This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 20730633a4 [Unity] Add support to append relay op attrs in translator
(#14356)
20730633a4 is described below
commit 20730633a4622dc545fcf257d8c0fe868021dc90
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Wed Mar 22 03:16:10 2023 +0530
[Unity] Add support to append relay op attrs in translator (#14356)
This PR adds an argument to the relay to relax translator to append the
relay Op Attrs as function attributes to the generated TIR PrimFuncs
This information is really useful in Relax as the absence of this
prevents us from being able to schedule efficiently for ops that are
heavily sensitive to the attributes.
For example, the `groups` attribute to `conv2d` op is needed to
differentiate between regular conv2d and depthwise conv2d.
---
python/tvm/relax/testing/relay_translator.py | 14 ++++++++++++++
tests/python/relax/test_relay_translator.py | 12 ++++++++++++
2 files changed, 26 insertions(+)
diff --git a/python/tvm/relax/testing/relay_translator.py
b/python/tvm/relax/testing/relay_translator.py
index c5225e14e7..46fdb7021d 100644
--- a/python/tvm/relax/testing/relay_translator.py
+++ b/python/tvm/relax/testing/relay_translator.py
@@ -38,6 +38,7 @@ def from_relay(
pass_config: Optional[Dict[str, Any]] = None,
disabled_pass: Optional[List[str]] = None,
translate_op_with_tir: Optional[Dict[str, tvm.tir.PrimFunc]] = None,
+ append_op_attrs: bool = False,
) -> IRModule:
"""Convert a Relay function into a Relax program.
@@ -65,6 +66,9 @@ def from_relay(
Dict that maps op names to user-defined PrimFuncs.
Takes relay operator names and forces them to user-defined PrimFuncs
during translation.
+ append_op_attrs: bool
+ Append relay op attrs to generated prim_funcs
+
Returns
-------
mod : tvm.IRModule
@@ -167,6 +171,15 @@ def from_relay(
attrs = node.attrs
out_type = node.checked_type
+ op_attrs_map = {}
+ if append_op_attrs:
+ func_attr_map = {"op_name": op_name}
+ if attrs:
+ for attr in attrs.keys():
+ func_attr_map[attr] = attrs[attr]
+
+ op_attrs_map["op_attrs"] = func_attr_map
+
if translate_op_with_tir and op_name in translate_op_with_tir:
tir_gvar = bb.add_func(translate_op_with_tir[op_name], op_name)
call = relax.call_tir(
@@ -191,6 +204,7 @@ def from_relay(
new_args,
node.checked_type,
primfunc_name_hint=name_hint,
+ primfunc_attrs=op_attrs_map,
)
output_var = var
diff --git a/tests/python/relax/test_relay_translator.py
b/tests/python/relax/test_relay_translator.py
index b4f84027eb..d3cd47b9e6 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -296,5 +296,17 @@ def test_translate_tuple_arg():
assert_structural_equal(relax_mod, bb.get())
+def test_append_op_attrs():
+ x = relay.var("x", shape=(10, 16))
+ y = relay.var("y", shape=(10, 16))
+ relay_mod = tvm.IRModule.from_expr(relay.Function([x, y],
relay.concatenate((x, y), axis=-1)))
+ relax_mod_wo_attrs = relay_translator.from_relay(relay_mod["main"],
target="llvm")
+ relax_mod_with_attrs = relay_translator.from_relay(
+ relay_mod["main"], target="llvm", append_op_attrs=True
+ )
+ assert "op_attrs" in relax_mod_with_attrs["concatenate"].attrs
+ assert "op_attrs" not in relax_mod_wo_attrs["concatenate"].attrs
+
+
if __name__ == "__main__":
pytest.main([__file__])