This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9dfcb22 [Runtime] Add graph_executor get_input_index API. (#8633)
9dfcb22 is described below
commit 9dfcb2285011203c5db204fe5bb97e681fd11445
Author: Hua Jiang <[email protected]>
AuthorDate: Tue Aug 3 18:32:14 2021 -0700
[Runtime] Add graph_executor get_input_index API. (#8633)
* [Runtime] Add graph_executor get_input_index API.
In the graph_executor use case, users can call set_input with
an input index to set an input parameter, but there is no
straightforward way to get the correct index from an input name.
This commit provides a get_input_index API for that purpose.
* Update python/tvm/contrib/graph_executor.py
Co-authored-by: Cody Yu <[email protected]>
* Update python/tvm/contrib/graph_executor.py
Co-authored-by: Cody Yu <[email protected]>
* Update src/runtime/graph_executor/graph_executor.cc
Co-authored-by: Cody Yu <[email protected]>
* Update python/tvm/contrib/graph_executor.py
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
---
python/tvm/contrib/graph_executor.py | 16 ++++++++++++++++
src/runtime/graph_executor/graph_executor.cc | 5 +++++
tests/python/relay/test_backend_graph_executor.py | 14 ++++++++++++++
3 files changed, 35 insertions(+)
diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py
index a4bc859..f9d1b97 100644
--- a/python/tvm/contrib/graph_executor.py
+++ b/python/tvm/contrib/graph_executor.py
@@ -157,6 +157,7 @@ class GraphModule(object):
self._get_output = module["get_output"]
self._get_input = module["get_input"]
self._get_num_outputs = module["get_num_outputs"]
+ self._get_input_index = module["get_input_index"]
self._get_num_inputs = module["get_num_inputs"]
self._load_params = module["load_params"]
self._share_params = module["share_params"]
@@ -242,6 +243,21 @@ class GraphModule(object):
return self._get_input(index)
+ def get_input_index(self, name):
+ """Get the index of an input from its name.
+
+ Parameters
+ ----------
+ name : str
+ The name of the input.
+
+ Returns
+ -------
+ index : int
+ The input index, or -1 if no input with the given name exists.
+ """
+ return self._get_input_index(name)
+
def get_output(self, index, out=None):
"""Get index-th output to out
diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc
index 7aae12b..bc73a59 100644
--- a/src/runtime/graph_executor/graph_executor.cc
+++ b/src/runtime/graph_executor/graph_executor.cc
@@ -502,6 +502,11 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
this->ShareParams(dynamic_cast<const GraphExecutor&>(*module.operator->()), &strm);
});
+ } else if (name == "get_input_index") {
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
+ *rv = this->GetInputIndex(args[0].operator String());
+ });
} else {
return PackedFunc();
}
diff --git a/tests/python/relay/test_backend_graph_executor.py b/tests/python/relay/test_backend_graph_executor.py
index 234095f..7beac19 100644
--- a/tests/python/relay/test_backend_graph_executor.py
+++ b/tests/python/relay/test_backend_graph_executor.py
@@ -311,5 +311,19 @@ def test_graph_executor_nested_tuples():
tvm.testing.assert_allclose(out[1][1][1].numpy(), data[3])
+def test_graph_executor_api():
+ """Build a two-input add graph; get_input_index must return each input's position (0, 1) and -1 for an unknown name."""
+ dname_0, dname_1 = "data_0", "data_1"
+ data_0, data_1 = [relay.var(c, shape=(1, 1), dtype="float32") for c in
[dname_0, dname_1]]
+ net = relay.add(data_0, data_1)
+ func = relay.Function((data_0, data_1), net)
+
+ lib = relay.build(tvm.IRModule.from_expr(func), "llvm")
+ mod = graph_executor.GraphModule(lib["default"](tvm.cpu(0)))
+
+ assert mod.get_input_index(dname_1) == 1
+ assert mod.get_input_index(dname_0) == 0
+ assert mod.get_input_index("Invalid") == -1
+
+
if __name__ == "__main__":
pytest.main([__file__])