This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 41835d1  [Relay] Expose FunctionGetAttr to Python (#4905)
41835d1 is described below

commit 41835d176d31bc2f3ba1f0ed9e35bdbfd453dc39
Author: Jon Soifer <soif...@gmail.com>
AuthorDate: Tue Feb 18 13:44:59 2020 -0800

    [Relay] Expose FunctionGetAttr to Python (#4905)
    
    * [Relay] Expose FunctionGetAttr to Python
    
    * add test
    
    Co-authored-by: Jon Soifer <jo...@microsoft.com>
---
 python/tvm/relay/expr.py            | 3 +++
 src/relay/ir/expr.cc                | 6 ++++++
 tests/python/relay/test_ir_nodes.py | 2 ++
 3 files changed, 11 insertions(+)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index e5259fb..39e68b8 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -280,6 +280,9 @@ class Function(BaseFunc):
     def set_attribute(self, name, ref):
         return _expr.FunctionSetAttr(self, name, ref)
 
+    def get_attribute(self, name):
+        return _expr.FunctionGetAttr(self, name)
+
 
 @register_relay_node
 class Call(ExprWithOp):
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index 89395bb..0292a6c 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -360,6 +360,12 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
     return FunctionSetAttr(func, name, ref);
 });
 
+TVM_REGISTER_GLOBAL("relay._expr.FunctionGetAttr")
+.set_body_typed(
+  [](Function func, std::string name) {
+    return FunctionGetAttr(func, name);
+});
+
 TVM_REGISTER_GLOBAL("relay._make.Any")
 .set_body_typed([]() { return Any::make(); });
 
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index bdda72c..b7d7eb9 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -168,10 +168,12 @@ def test_function():
     body = relay.Tuple(tvm.convert([]))
     type_params = tvm.convert([])
     fn = relay.Function(params, body, ret_type, type_params)
+    fn = fn.set_attribute("test_attribute", tvm.tir.StringImm("value"))
     assert fn.params == params
     assert fn.body == body
     assert fn.type_params == type_params
     assert fn.span == None
+    assert fn.get_attribute("test_attribute") == "value"
     str(fn)
     check_json_roundtrip(fn)
 

Reply via email to