zxy844288792 commented on a change in pull request #8454:
URL: https://github.com/apache/tvm/pull/8454#discussion_r675704461
##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -215,10 +232,134 @@ def from_tensorflow(
)
return func, self._params
+ def _analysis_tensor_list_op(
+ self,
+ graph,
+ node,
+ tl_write_nodes,
+ tl_stack_nodes,
+ tl_construct_nodes,
+ sub_func_name="",
+ root_node="",
+ ):
+ if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+ self._sub_input_idx_map[sub_func_name] = {}
+
+ if node.op == "Placeholder":
+ # record placeholder node in sub functions
+ self._sub_map[sub_func_name] = node
+ self._sub_input_idx_map[sub_func_name][node.name] = len(
+ self._sub_input_idx_map[sub_func_name]
+ )
+
+ if node.op.startswith("TensorList"):
+ if is_tensor_list_constuctor(node):
+ tl_construct_nodes.append(node)
+ else:
+ for tl_write_name, idx in _tensor_list_write_ops.items():
+ if node.op.startswith(tl_write_name):
+ tl_write_nodes.append((node, idx, sub_func_name,
root_node))
+ if node.op.startswith("TensorListStack"):
+ tl_stack_nodes.append(node)
+ elif node.op.startswith("StatelessWhile"):
+ root_node = node.name
+ cond_fn_name, body_fn_name = [
+ parse_attr(node.attr).get(x).name for x in ["cond", "body"]
+ ]
+ for fn_name in [cond_fn_name, body_fn_name]:
+ subfunction = self._gdef_lib[fn_name]
+ sub_func_name = fn_name
+ for sub_node in subfunction.node:
+ # bypass const node
+ if sub_node.op == "Const":
+ continue
+ self._tf_node_map[sub_node.name] = sub_node
+ self._analysis_tensor_list_op(
+ subfunction,
+ sub_node,
+ tl_write_nodes,
+ tl_stack_nodes,
+ tl_construct_nodes,
+ sub_func_name=sub_func_name,
+ root_node=root_node,
+ )
+
+ def _infer_static_shape_stack_node(self, tl_stack_nodes):
+ for stack_node in tl_stack_nodes:
+ if len(stack_node.input) < 2:
+ # Stack node does not have shape
+ continue
+ input_shape_name = stack_node.input[1].split(":")[0]
+ input_shape_node = self._tf_node_map[input_shape_name]
+ stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
+ in_idx = -1
+ while stack:
+ cnode = stack.pop(0)
+ if not cnode.op.startswith("TensorList"):
+ if in_idx and cnode.op.startswith("StatelessWhile"):
+
stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
+ else:
+ for iname in cnode.input:
+ if
self._tf_node_map[iname.split(":")[0]].op.startswith(
+ "StatelessWhile"
+ ):
+ # identify input index based on output index
+ if iname.split(":")[1]:
+ in_idx = int(iname.split(":")[1])
+
stack.append(self._tf_node_map[iname.split(":")[0]])
+ # identify the corresponding constructor node and add shape to
_tensor_list_shapes
+ elif cnode.name != stack_node.name:
+ if is_tensor_list_constuctor(cnode):
+ shape_attr = parse_attr(input_shape_node.attr)
+ if "value" not in shape_attr:
Review comment:
Yes — in that case we will treat it as a dynamic shape. It will still pass as a
unit test, but in a real model it might cause problems if it is followed by some
ops such as Conv2D.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]