This is an automated email from the ASF dual-hosted git repository.
zhic pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 7a41971 Fix tf parser (#5794)
7a41971 is described below
commit 7a419718c121164fc260864014e1d0d81f556949
Author: Yao Wang <[email protected]>
AuthorDate: Fri Jun 12 20:32:46 2020 -0700
Fix tf parser (#5794)
---
python/tvm/relay/frontend/tensorflow.py | 12 ++++--------
python/tvm/relay/frontend/tensorflow_parser.py | 10 ++++++++--
2 files changed, 12 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 5778b25..af09877 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1322,14 +1322,10 @@ def _shape():
def _fill():
def _impl(inputs, attr, params, mod):
- output_shape = attr['_output_shapes'][0]
- # Output shape must be defined to avoid errors. If any axis is not, we
must
- # try to compute its shape.
- if output_shape is None or -1 in output_shape:
- try:
- output_shape = _expr.Constant(_infer_value(inputs[0], params,
mod))
- except Exception:
- output_shape = inputs[0]
+ try:
+ output_shape = _infer_value(inputs[0], params,
mod).asnumpy().tolist()
+ except Exception:
+ output_shape = inputs[0]
return _op.full(inputs[1], output_shape, attr['T'].name)
return _impl
diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py
index fdbb876..771aed0 100644
--- a/python/tvm/relay/frontend/tensorflow_parser.py
+++ b/python/tvm/relay/frontend/tensorflow_parser.py
@@ -30,6 +30,10 @@ class TFParser(object):
model_dir : tensorflow frozen pb file or a directory that contains saved
model or checkpoints.
+ outputs : List of output tensor names (Optional)
+ Optional output node names. This will be protected for saved model
+ when we do remove training nodes.
+
Examples
--------
.. code-block:: python
@@ -38,11 +42,12 @@ class TFParser(object):
graphdef = parser.parse()
"""
- def __init__(self, model_dir):
+ def __init__(self, model_dir, outputs=None):
from tensorflow.core.framework import graph_pb2
self._tmp_dir = util.tempdir()
self._model_dir = model_dir
self._graph = graph_pb2.GraphDef()
+ self._outputs = outputs or []
def _set_graph(self, graph):
"""Set Graph"""
@@ -128,7 +133,8 @@ class TFParser(object):
output_graph_def = graph_pb2.GraphDef()
with open(output_graph_filename, "rb") as f:
output_graph_def.ParseFromString(f.read())
- output_graph_def =
graph_util.remove_training_nodes(output_graph_def)
+ output_graph_def =
graph_util.remove_training_nodes(output_graph_def,
+
protected_nodes=self._outputs)
return output_graph_def
def _load_ckpt(self):