This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ad3dfb4c1c [Bugfix][Executor] fix debug_executor function debug_get_output (#16492)
ad3dfb4c1c is described below
commit ad3dfb4c1c750a006f8cc065a5ef2c3dabf0d89f
Author: JiaXing Shi <[email protected]>
AuthorDate: Thu Feb 22 14:37:17 2024 +0800
[Bugfix][Executor] fix debug_executor function debug_get_output (#16492)
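Fix debug_executor's debug_get_output: when no output buffer is passed, return the node's output as an NDArray instead of requiring a preallocated buffer.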
---
python/tvm/contrib/debugger/debug_executor.py | 6 ++++--
.../graph_executor/debug/graph_executor_debug.cc | 23 ++++++++++++++++++++--
.../graph_executor/debug/graph_executor_debug.h | 12 +++++++++++
3 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py
index 75932c0d5e..785959ce8d 100644
--- a/python/tvm/contrib/debugger/debug_executor.py
+++ b/python/tvm/contrib/debugger/debug_executor.py
@@ -272,8 +272,10 @@ class GraphModuleDebug(graph_executor.GraphModule):
node_index = node
else:
raise RuntimeError("Require node index or name only.")
-
- self._debug_get_output(node_index, out)
+ if out:
+ self._debug_get_output(node_index, out)
+ return out
+ return self._debug_get_output(node_index)
# pylint: disable=arguments-differ
def run(
diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc
index 0dbcbff46f..892a13b46b 100644
--- a/src/runtime/graph_executor/debug/graph_executor_debug.cc
+++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc
@@ -197,10 +197,17 @@ PackedFunc GraphExecutorDebug::GetFunction(const String& name,
// return member functions during query.
if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ int args0 = -1;  // Node index, resolved below from either a node name (String) or an integer.
if (String::CanConvertFrom(args[0])) {
- this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
+ args0 = this->GetNodeIndex(args[0]);
} else {
- this->DebugGetNodeOutput(args[0], args[1]);
+ args0 = args[0];
+ }
+ // With 2 args, copy the output into the caller's buffer args[1]; with 1 arg, return the output NDArray via *rv.
+ if (args.num_args == 2) {
+ this->DebugGetNodeOutput(args0, args[1]);
+ } else {
+ *rv = this->DebugGetNodeOutput(args0);
}
});
} else if (name == "execute_node") {
@@ -325,6 +332,18 @@ void GraphExecutorDebug::DebugGetNodeOutput(int index, DLTensor* data_out) {
data_entry_[eid].CopyTo(data_out);
}
+NDArray GraphExecutorDebug::DebugGetNodeOutput(int index) {
+ ICHECK_LT(static_cast<size_t>(index), op_execs_.size());
+ uint32_t eid = index;  // NOTE(review): node index is used directly as the data_entry_ id, whereas GetNodeOutput uses entry_id(node, out_ind); confirm these agree when earlier nodes have multiple outputs.
+ // Partial run: execute ops 0..index (inclusive) in order, skipping empty op_execs_ slots.
+ for (size_t i = 0; i < op_execs_.size(); ++i) {
+ if (op_execs_[i]) op_execs_[i]();
+ if (static_cast<int>(i) == index) break;
+ }
+ // Return the stored NDArray itself; the DLTensor* overload instead copies it out via CopyTo.
+ return data_entry_[eid];
+}
+
NDArray GraphExecutorDebug::GetNodeOutput(int node, int out_ind) {
ICHECK_EQ(node, last_executed_node_);
ICHECK_LT(entry_id(node, out_ind), data_entry_.size());
diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.h b/src/runtime/graph_executor/debug/graph_executor_debug.h
index 7c9d8f2cd1..3820830566 100644
--- a/src/runtime/graph_executor/debug/graph_executor_debug.h
+++ b/src/runtime/graph_executor/debug/graph_executor_debug.h
@@ -122,6 +122,18 @@ class GraphExecutorDebug : public GraphExecutor {
*/
void DebugGetNodeOutput(int index, DLTensor* data_out);
+ /*!
+ * \brief Return the output of the index-th node.
+ *
+ * Runs the graph from the beginning up to and including the index-th node
+ * and returns that node's output, instead of copying it into a caller-provided
+ * DLTensor as the other overload does. Costly; intended for debugging only.
+ *
+ * \param index The index of the node.
+ * \return The NDArray output of the index-th node.
+ */
+ NDArray DebugGetNodeOutput(int index);
+
/*!
* \brief Profile execution time of the module.
*