This is an automated email from the ASF dual-hosted git repository. marisa pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new bfa4eae Add 'get_num_inputs' to GraphRuntime (#6118) bfa4eae is described below commit bfa4eae1dcac7f2493e543823e51eb420b0f8b2c Author: Alexander Booth <adbub...@gmail.com> AuthorDate: Fri Jul 24 07:22:39 2020 -0700 Add 'get_num_inputs' to GraphRuntime (#6118) --- python/tvm/contrib/graph_runtime.py | 11 +++++++++++ src/runtime/graph/graph_runtime.cc | 9 +++++++++ src/runtime/graph/graph_runtime.h | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py index ec102f5..326eccb 100644 --- a/python/tvm/contrib/graph_runtime.py +++ b/python/tvm/contrib/graph_runtime.py @@ -133,6 +133,7 @@ class GraphModule(object): self._get_output = module["get_output"] self._get_input = module["get_input"] self._get_num_outputs = module["get_num_outputs"] + self._get_num_inputs = module["get_num_inputs"] self._load_params = module["load_params"] self._share_params = module["share_params"] @@ -187,6 +188,16 @@ class GraphModule(object): """ return self._get_num_outputs() + def get_num_inputs(self): + """Get the number of inputs to the graph + + Returns + ------- + count : int + The number of inputs. + """ + return self._get_num_inputs() + def get_input(self, index, out=None): """Get index-th input to out diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index e984861..18245ba 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -135,6 +135,12 @@ void GraphRuntime::SetInputZeroCopy(int index, DLTensor* data_ref) { */ int GraphRuntime::NumOutputs() const { return outputs_.size(); } /*! + * \brief Get the number of inputs + * + * \return The number of inputs to the graph. + */ +int GraphRuntime::NumInputs() const { return input_nodes_.size(); } +/*! * \brief Return NDArray for given input index. * \param index The input index. * @@ -433,6 +439,9 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name, } else if (name == "get_num_outputs") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); }); + } else if (name == "get_num_inputs") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); }); } else if (name == "run") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); }); } else if (name == "load_params") { diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index d0c9822..dcef1e4 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -125,6 +125,12 @@ class TVM_DLL GraphRuntime : public ModuleNode { */ int NumOutputs() const; /*! + * \brief Get the number of inputs + * + * \return The number of inputs to the graph. + */ + int NumInputs() const; + /*! * \brief Return NDArray for given input index. * \param index The input index. *