This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ad3dfb4c1c [Bugfix][Executor] fix debug_executor function debug_get_output (#16492)
ad3dfb4c1c is described below

commit ad3dfb4c1c750a006f8cc065a5ef2c3dabf0d89f
Author: JiaXing Shi <[email protected]>
AuthorDate: Thu Feb 22 14:37:17 2024 +0800

    [Bugfix][Executor] fix debug_executor function debug_get_output (#16492)
    
    fix debug_executor function debug_get_output
---
 python/tvm/contrib/debugger/debug_executor.py      |  6 ++++--
 .../graph_executor/debug/graph_executor_debug.cc   | 23 ++++++++++++++++++++--
 .../graph_executor/debug/graph_executor_debug.h    | 12 +++++++++++
 3 files changed, 37 insertions(+), 4 deletions(-)

diff --git a/python/tvm/contrib/debugger/debug_executor.py b/python/tvm/contrib/debugger/debug_executor.py
index 75932c0d5e..785959ce8d 100644
--- a/python/tvm/contrib/debugger/debug_executor.py
+++ b/python/tvm/contrib/debugger/debug_executor.py
@@ -272,8 +272,10 @@ class GraphModuleDebug(graph_executor.GraphModule):
             node_index = node
         else:
             raise RuntimeError("Require node index or name only.")
-
-        self._debug_get_output(node_index, out)
+        if out:
+            self._debug_get_output(node_index, out)
+            return out
+        return self._debug_get_output(node_index)
 
     # pylint: disable=arguments-differ
     def run(
diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.cc b/src/runtime/graph_executor/debug/graph_executor_debug.cc
index 0dbcbff46f..892a13b46b 100644
--- a/src/runtime/graph_executor/debug/graph_executor_debug.cc
+++ b/src/runtime/graph_executor/debug/graph_executor_debug.cc
@@ -197,10 +197,17 @@ PackedFunc GraphExecutorDebug::GetFunction(const String& name,
   // return member functions during query.
   if (name == "debug_get_output") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      int args0 = -1;
       if (String::CanConvertFrom(args[0])) {
-        this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
+        args0 = this->GetNodeIndex(args[0]);
       } else {
-        this->DebugGetNodeOutput(args[0], args[1]);
+        args0 = args[0];
+      }
+
+      if (args.num_args == 2) {
+        this->DebugGetNodeOutput(args0, args[1]);
+      } else {
+        *rv = this->DebugGetNodeOutput(args0);
       }
     });
   } else if (name == "execute_node") {
@@ -325,6 +332,18 @@ void GraphExecutorDebug::DebugGetNodeOutput(int index, DLTensor* data_out) {
   data_entry_[eid].CopyTo(data_out);
 }
 
+NDArray GraphExecutorDebug::DebugGetNodeOutput(int index) {
+  ICHECK_LT(static_cast<size_t>(index), op_execs_.size());
+  uint32_t eid = index;
+
+  for (size_t i = 0; i < op_execs_.size(); ++i) {
+    if (op_execs_[i]) op_execs_[i]();
+    if (static_cast<int>(i) == index) break;
+  }
+
+  return data_entry_[eid];
+}
+
 NDArray GraphExecutorDebug::GetNodeOutput(int node, int out_ind) {
   ICHECK_EQ(node, last_executed_node_);
   ICHECK_LT(entry_id(node, out_ind), data_entry_.size());
diff --git a/src/runtime/graph_executor/debug/graph_executor_debug.h b/src/runtime/graph_executor/debug/graph_executor_debug.h
index 7c9d8f2cd1..3820830566 100644
--- a/src/runtime/graph_executor/debug/graph_executor_debug.h
+++ b/src/runtime/graph_executor/debug/graph_executor_debug.h
@@ -122,6 +122,18 @@ class GraphExecutorDebug : public GraphExecutor {
    */
   void DebugGetNodeOutput(int index, DLTensor* data_out);
 
+  /*!
+   * \brief return output of index-th node.
+   *
+   * This method will do a partial run of the graph
+   * from the beginning up to the index-th node and return its output.
+   * This is a costly operation; use it only for debugging purposes.
+   *
+   * \param index: The  index of the node.
+   *
+   */
+  NDArray DebugGetNodeOutput(int index);
+
   /*!
    * \brief Profile execution time of the module.
    *

Reply via email to