This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 41835d1 [Relay] Expose FunctionGetAttr to Python (#4905)
41835d1 is described below
commit 41835d176d31bc2f3ba1f0ed9e35bdbfd453dc39
Author: Jon Soifer <[email protected]>
AuthorDate: Tue Feb 18 13:44:59 2020 -0800
[Relay] Expose FunctionGetAttr to Python (#4905)
* [Relay] Expose FunctionGetAttr to Python
* add test
Co-authored-by: Jon Soifer <[email protected]>
---
python/tvm/relay/expr.py | 3 +++
src/relay/ir/expr.cc | 6 ++++++
tests/python/relay/test_ir_nodes.py | 2 ++
3 files changed, 11 insertions(+)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index e5259fb..39e68b8 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -280,6 +280,9 @@ class Function(BaseFunc):
def set_attribute(self, name, ref):
return _expr.FunctionSetAttr(self, name, ref)
+ def get_attribute(self, name):
+ return _expr.FunctionGetAttr(self, name)
+
@register_relay_node
class Call(ExprWithOp):
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 89395bb..0292a6c 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
return FunctionSetAttr(func, name, ref);
});
+TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
+.set_body_typed(
+ [](Function func, std::string name) {
+ return FunctionGetAttr(func, name);
+});
+
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed([]() { return Any::make(); });
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index bdda72c..b7d7eb9 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -168,10 +168,12 @@ def test_function():
body = relay.Tuple(tvm.convert([]))
type_params = tvm.convert([])
fn = relay.Function(params, body, ret_type, type_params)
+ fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value"))
assert fn.params == params
assert fn.body == body
assert fn.type_params == type_params
assert fn.span == None
+ assert fn.get_attribute("test_attribute") == "value"
str(fn)
check_json_roundtrip(fn)