This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 482e341 Fix JSON graph dumping. (#5591)
482e341 is described below
commit 482e34107054a08324b29d4078dfaecbe3c68430
Author: Andrew Reusch <[email protected]>
AuthorDate: Wed May 13 18:20:26 2020 -0700
Fix JSON graph dumping. (#5591)
* Previously dump_graph_json() wrote a JSON-escaped string containing
the JSON-encoded graph (i.e. the graph was double-encoded) instead of
the graph object itself.
---
python/tvm/contrib/debugger/debug_result.py | 8 ++++----
tests/python/unittest/test_runtime_graph_debug.py | 13 +++++++++++--
2 files changed, 15 insertions(+), 6 deletions(-)
diff --git a/python/tvm/contrib/debugger/debug_result.py b/python/tvm/contrib/debugger/debug_result.py
index 18920c6..b1fe1b6 100644
--- a/python/tvm/contrib/debugger/debug_result.py
+++ b/python/tvm/contrib/debugger/debug_result.py
@@ -53,9 +53,9 @@ class DebugResult(object):
self._dump_path = dump_path
self._output_tensor_list = []
self._time_list = []
- self._parse_graph(graph_json)
+ json_obj = self._parse_graph(graph_json)
# dump the json information
- self.dump_graph_json(graph_json)
+ self._dump_graph_json(json_obj)
def _parse_graph(self, graph_json):
"""Parse and extract the JSON graph and update the nodes, shapes and
dltype.
@@ -70,12 +70,12 @@ class DebugResult(object):
self._shapes_list = json_obj['attrs']['shape']
self._dtype_list = json_obj['attrs']['dltype']
self._update_graph_json()
+ return json_obj
def _update_graph_json(self):
"""update the nodes_list with name, shape and data type,
for temporarily storing the output.
"""
-
nodes_len = len(self._nodes_list)
for i in range(nodes_len):
node = self._nodes_list[i]
@@ -192,7 +192,7 @@ class DebugResult(object):
with open(os.path.join(self._dump_path, CHROME_TRACE_FILE_NAME), "w")
as trace_f:
json.dump(result, trace_f)
- def dump_graph_json(self, graph):
+ def _dump_graph_json(self, graph):
"""Dump json formatted graph.
Parameters
diff --git a/tests/python/unittest/test_runtime_graph_debug.py b/tests/python/unittest/test_runtime_graph_debug.py
index 658d9eb..ce47b16 100644
--- a/tests/python/unittest/test_runtime_graph_debug.py
+++ b/tests/python/unittest/test_runtime_graph_debug.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import json
import os
import tvm
from tvm import te
import numpy as np
-import json
from tvm import rpc
from tvm.contrib import util
from tvm.contrib.debugger import debug_runtime as graph_runtime
@@ -75,7 +75,16 @@ def test_graph_simple():
assert(len(os.listdir(directory)) == 1)
#verify the file name is proper
- assert(os.path.exists(os.path.join(directory, GRAPH_DUMP_FILE_NAME)))
+ graph_dump_path = os.path.join(directory, GRAPH_DUMP_FILE_NAME)
+ assert(os.path.exists(graph_dump_path))
+
+ # verify the graph contains some expected keys
+ with open(graph_dump_path) as graph_f:
+ dumped_graph = json.load(graph_f)
+
+ assert isinstance(dumped_graph, dict)
+ for k in ("nodes", "arg_nodes", "node_row_ptr", "heads", "attrs"):
+ assert k in dumped_graph, f"key {k} not in dumped graph {graph!r}"
mod.run()
#Verify the tensors are dumped