This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 482e341  Fix JSON graph dumping. (#5591)
482e341 is described below

commit 482e34107054a08324b29d4078dfaecbe3c68430
Author: Andrew Reusch <[email protected]>
AuthorDate: Wed May 13 18:20:26 2020 -0700

    Fix JSON graph dumping. (#5591)
    
    * Previously this function wrote a JSON-escaped string containing
       the JSON-encoded graph to the graph dump file, instead of writing
       the graph itself as a JSON object.
---
 python/tvm/contrib/debugger/debug_result.py       |  8 ++++----
 tests/python/unittest/test_runtime_graph_debug.py | 13 +++++++++++--
 2 files changed, 15 insertions(+), 6 deletions(-)

diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py
index 18920c6..b1fe1b6 100644
--- a/python/tvm/contrib/debugger/debug_result.py
+++ b/python/tvm/contrib/debugger/debug_result.py
@@ -53,9 +53,9 @@ class DebugResult(object):
         self._dump_path = dump_path
         self._output_tensor_list = []
         self._time_list = []
-        self._parse_graph(graph_json)
+        json_obj = self._parse_graph(graph_json)
         # dump the json information
-        self.dump_graph_json(graph_json)
+        self._dump_graph_json(json_obj)
 
     def _parse_graph(self, graph_json):
         """Parse and extract the JSON graph and update the nodes, shapes and 
dltype.
@@ -70,12 +70,12 @@ class DebugResult(object):
         self._shapes_list = json_obj['attrs']['shape']
         self._dtype_list = json_obj['attrs']['dltype']
         self._update_graph_json()
+        return json_obj
 
     def _update_graph_json(self):
         """update the nodes_list with name, shape and data type,
         for temporarily storing the output.
         """
-
         nodes_len = len(self._nodes_list)
         for i in range(nodes_len):
             node = self._nodes_list[i]
@@ -192,7 +192,7 @@ class DebugResult(object):
         with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w") 
as trace_f:
             json.dump(result, trace_f)
 
-    def dump_graph_json(self, graph):
+    def _dump_graph_json(self, graph):
         """Dump json formatted graph.
 
         Parameters
diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py
index 658d9eb..ce47b16 100644
--- a/tests/python/unittest/test_runtime_graph_debug.py
+++ b/tests/python/unittest/test_runtime_graph_debug.py
@@ -14,11 +14,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
 import os
 import tvm
 from tvm import te
 import numpy as np
-import json
 from tvm import rpc
 from tvm.contrib import util
 from tvm.contrib.debugger import debug_runtime as graph_runtime
@@ -75,7 +75,16 @@ def test_graph_simple():
         assert(len(os.listdir(directory)) == 1)
 
         #verify the file name is proper
-        assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME)))
+        graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME)
+        assert(os.path.exists(graph_dump_path))
+
+        # verify the graph contains some expected keys
+        with open(graph_dump_path) as graph_f:
+            dumped_graph = json.load(graph_f)
+
+        assert isinstance(dumped_graph, dict)
+        for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"):
+            assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}"
 
         mod.run()
         #Verify the tensors are dumped

Reply via email to