This is an automated email from the ASF dual-hosted git repository. zhic pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new 41835d1 [Relay] Expose FunctionGetAttr to Python (#4905) 41835d1 is described below commit 41835d176d31bc2f3ba1f0ed9e35bdbfd453dc39 Author: Jon Soifer <soif...@gmail.com> AuthorDate: Tue Feb 18 13:44:59 2020 -0800 [Relay] Expose FunctionGetAttr to Python (#4905) * [Relay] Expose FunctionGetAttr to Python * add test Co-authored-by: Jon Soifer <jo...@microsoft.com> --- python/tvm/relay/expr.py | 3 +++ src/relay/ir/expr.cc | 6 ++++++ tests/python/relay/test_ir_nodes.py | 2 ++ 3 files changed, 11 insertions(+) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index e5259fb..39e68b8 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -280,6 +280,9 @@ class Function(BaseFunc): def set_attribute(self, name, ref): return _expr.FunctionSetAttr(self, name, ref) + def get_attribute(self, name): + return _expr.FunctionGetAttr(self, name) + @register_relay_node class Call(ExprWithOp): diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 89395bb..0292a6c 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") return FunctionSetAttr(func, name, ref); }); +TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr") +.set_body_typed( + [](Function func, std::string name) { + return FunctionGetAttr(func, name); +}); + TVM_REGISTER_GLOBAL("relay._make.Any") .set_body_typed([]() { return Any::make(); }); diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index bdda72c..b7d7eb9 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -168,10 +168,12 @@ def test_function(): body = relay.Tuple(tvm.convert([])) type_params = tvm.convert([]) fn = relay.Function(params, body, ret_type, type_params) + fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value")) assert fn.params == params assert fn.body == body assert fn.type_params == type_params assert fn.span == None + assert fn.get_attribute("test_attribute") == "value" str(fn) check_json_roundtrip(fn)