This is an automated email from the ASF dual-hosted git repository.
kevinthesun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4f9d4d7 [TF] Support TupleWrapper as direct ancestor of control flow ops (#5639)
4f9d4d7 is described below
commit 4f9d4d78a046a8567fd93445c3b5fc2a68081259
Author: lixiaoquan <[email protected]>
AuthorDate: Wed May 27 02:29:29 2020 +0800
[TF] Support TupleWrapper as direct ancestor of control flow ops (#5639)
---
python/tvm/relay/frontend/tensorflow.py | 59 +++++++++-------------
.../frontend/tensorflow/test_control_flow.py | 20 ++++++++
2 files changed, 45 insertions(+), 34 deletions(-)
diff --git a/python/tvm/relay/frontend/tensorflow.py
b/python/tvm/relay/frontend/tensorflow.py
index ab9e9e6..d4b73f9 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -3073,21 +3073,19 @@ class GraphProto(object):
branch = self._branches[node_name_prefix]
false_br = self._backtrack_construct(node.input[0])
true_br = self._backtrack_construct(node.input[1])
- assert len(true_br) == 1
- assert len(false_br) == 1
- branch.true_branch = true_br[0]
- branch.false_branch = false_br[0]
- op = [branch.if_node()]
+ branch.true_branch = true_br
+ branch.false_branch = false_br
+ op = branch.if_node()
if node_name_prefix not in self._while_loop_name_set:
try:
cond_val = np.all(_infer_value(branch.cond,
self._params,
self._mod).asnumpy())
if cond_val:
- op = [branch.true_branch]
+ op = branch.true_branch
else:
- op = [branch.false_branch]
+ op = branch.false_branch
except Exception:
- op = [branch.if_node()]
+ op = branch.if_node()
elif node.op == "Exit":
loop = self._loops[node_name_prefix]
@@ -3113,17 +3111,15 @@ class GraphProto(object):
if exit_number == j:
body_pos = i
break
- op = [_expr.TupleGetItem(expr, body_pos)]
+ op = _expr.TupleGetItem(expr, body_pos)
elif node.op == "Enter":
op = self._backtrack_construct(node.input[0])
elif node.op == "LoopCond":
op = self._backtrack_construct(node.input[0])
- assert len(op) == 1
- self._loops[node_name_prefix].cond = op[0]
+ self._loops[node_name_prefix].cond = op
elif node.op == "Switch":
op = self._backtrack_construct(node.input[0])
cond = self._backtrack_construct(node.input[1])
- assert len(op) == 1
if _in_while_loop(self._control_flow_node_map, node_name_prefix):
if node_name_prefix not in self._loop_var_order:
self._loop_var_order[node_name_prefix] = []
@@ -3132,11 +3128,11 @@ class GraphProto(object):
else:
self._loop_var_order[node_name_prefix].\
append(int(node.name.split("Switch_")[-1]))
- self._loops[node_name_prefix].loop_vars.append(op[0])
+ self._loops[node_name_prefix].loop_vars.append(op)
else:
if node_name_prefix not in self._branches:
self._branches[node_name_prefix] = Branch()
- self._branches[node_name_prefix].cond = cond[0]
+ self._branches[node_name_prefix].cond = cond
elif node.op == "NextIteration":
if node_name_prefix not in self._loop_body_order:
self._loop_body_order[node_name_prefix] = []
@@ -3146,9 +3142,7 @@ class GraphProto(object):
self._loop_body_order[node_name_prefix].\
append(int(node.name.split("NextIteration_")[-1]))
op = self._backtrack_construct(node.input[0])
-
- assert len(op) == 1
- self._loops[node_name_prefix].body.append(op[0])
+ self._loops[node_name_prefix].body.append(op)
else:
raise Exception("Cannot identify control flow operator: " +
"{}".format(node.op))
@@ -3219,10 +3213,10 @@ class GraphProto(object):
op : relay.Expr
Converted relay expression
"""
- node_name = node_name.split(':')[0].split("^")[-1]
+ input_op_name = node_name.split(':')[0].split("^")[-1]
- if node_name not in self._nodes:
- node = self._tf_node_map[node_name]
+ if input_op_name not in self._nodes:
+ node = self._tf_node_map[input_op_name]
attr = self._parse_attr(node.attr)
if node.op in _control_flow_nodes:
@@ -3231,20 +3225,10 @@ class GraphProto(object):
attr,
self._control_flow_node_map)
else:
- attr["_output_shapes"] = self._output_shapes[node_name]
+ attr["_output_shapes"] = self._output_shapes[input_op_name]
attr["_node_name"] = node.name
attr["_target_layout"] = self._layout
- inputs = []
- for iname in node.input:
- in_op = self._backtrack_construct(iname)
- if isinstance(in_op, _expr.TupleWrapper):
- tn = iname.split(':')
- tensor_slot = int(tn[1]) if len(tn) > 1 else 0
- in_op = in_op[tensor_slot]
- else:
- in_op = in_op[0]
-
- inputs.append(in_op)
+ inputs = [self._backtrack_construct(iname) for iname in
node.input]
op = self._convert_operator(node.op, inputs, attr, self._graph)
if isinstance(op, np.ndarray):
@@ -3258,9 +3242,16 @@ class GraphProto(object):
node_hash = s_hash(op) if isinstance(op, _expr.Tuple) else
s_hash(op[0])
self._hash2tfnode[node_hash] = node
- self._nodes[node_name] = op
+ self._nodes[input_op_name] = op
+
+ out = self._nodes[input_op_name]
+
+ if isinstance(out, _expr.TupleWrapper):
+ tn = node_name.split(':')
+ tensor_slot = int(tn[1]) if len(tn) > 1 else 0
+ return out[tensor_slot]
- return self._nodes[node_name]
+ return out[0]
def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
"""Load tensorflow graph which is a python tensorflow graph object into
relay.
diff --git a/tests/python/frontend/tensorflow/test_control_flow.py
b/tests/python/frontend/tensorflow/test_control_flow.py
index 9777a8d..9003527 100644
--- a/tests/python/frontend/tensorflow/test_control_flow.py
+++ b/tests/python/frontend/tensorflow/test_control_flow.py
@@ -21,6 +21,7 @@ try:
tf.disable_v2_behavior()
except ImportError:
import tensorflow as tf
+from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
@@ -368,6 +369,23 @@ def test_nested_loop_bound():
check_equal(graph, tf_out, {dname: np_data})
+def test_switch():
+ graph = tf.Graph()
+
+ with graph.as_default():
+ data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
+ dname = 'data'
+ flag_name = 'flag'
+ data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype,
name=dname)
+ split = tf.split(data, 2, axis=0)
+ flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
+ output_false, output_true = control_flow_ops.switch(split[1], flag)
+ with tf.Session() as sess:
+ tf_out = sess.run(output_false, feed_dict={data.name: data_np,
flag.name: False})
+
+ check_equal(graph, tf_out, {dname: data_np, flag_name: False})
+
+
if __name__ == "__main__":
# tf.while_loop
test_vanilla_loop()
@@ -390,3 +408,5 @@ if __name__ == "__main__":
test_cond_in_loop()
test_vanilla_loop_bound()
test_nested_loop_bound()
+
+ test_switch()