comaniac commented on a change in pull request #8454:
URL: https://github.com/apache/tvm/pull/8454#discussion_r673343332



##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -66,6 +74,12 @@ def set_span(sym, node_name):
     return sym
 
 
+def is_tensor_list_constructor(tf_node):
+    """Check whether the node is a tensor list constructor."""
+    tl_name = "TensorListReserve"
+    return tf_node.op == tl_name

Review comment:
       ```suggestion
           return tf_node.op == "TensorListReserve"
       ```

##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -215,10 +233,130 @@ def from_tensorflow(
         )
         return func, self._params
 
+    def _analysis_tensor_list_op(
+        self,
+        graph,
+        node,
+        tl_write_nodes,
+        tl_stack_nodes,
+        tl_construct_nodes,
+        sub_func_name="",
+        root_node="",
+    ):
+        if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+            self._sub_input_idx_map[sub_func_name] = {}
+        if node.op == "Placeholder":
+            # record placeholder node in sub functions
+            self._sub_map[sub_func_name] = node
+            self._sub_input_idx_map[sub_func_name][node.name] = len(
+                self._sub_input_idx_map[sub_func_name]
+            )
+        if node.op.startswith("TensorList"):
+            if is_tensor_list_constructor(node):
+                tl_construct_nodes.append(node)
+            else:
+                for tl_write_name, idx in _tensor_list_write_ops.items():
+                    if node.op.startswith(tl_write_name):
+                        tl_write_nodes.append((node, idx, sub_func_name, root_node))
+                if node.op.startswith("TensorListStack"):
+                    tl_stack_nodes.append(node)
+        elif node.op.startswith("StatelessWhile"):
+            root_node = node.name
+            cond_fn_name, body_fn_name = [
+                parse_attr(node.attr).get(x).name for x in ["cond", "body"]
+            ]
+            for fn_name in [cond_fn_name, body_fn_name]:
+                subfunction = self._gdef_lib[fn_name]
+                sub_func_name = fn_name
+                for sub_node in subfunction.node:
+                    # bypass const node
+                    if sub_node.op == "Const":
+                        continue
+                    self._tf_node_map[sub_node.name] = sub_node
+                    self._analysis_tensor_list_op(
+                        subfunction,
+                        sub_node,
+                        tl_write_nodes,
+                        tl_stack_nodes,
+                        tl_construct_nodes,
+                        sub_func_name=sub_func_name,
+                        root_node=root_node,
+                    )
+
+    def _infer_static_shape_stack_node(self, tl_stack_nodes):
+        for stack_node in tl_stack_nodes:
+            if len(stack_node.input) < 2:
+                # Stack node does not have shape
+                continue
+            input_shape_name = stack_node.input[1].split(":")[0]
+            input_shape_node = self._tf_node_map[input_shape_name]
+            stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
+            in_idx = -1
+            while stack:
+                cnode = stack.pop(0)
+                if not cnode.op.startswith("TensorList"):
+                    if in_idx and cnode.op.startswith("StatelessWhile"):
+                        stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
+                    else:
+                        for iname in cnode.input:
+                            if self._tf_node_map[iname.split(":")[0]].op.startswith(
+                                "StatelessWhile"
+                            ):
+                                # identify input index based on output index
+                                if iname.split(":")[1]:
+                                    in_idx = int(iname.split(":")[1])
+                            stack.append(self._tf_node_map[iname.split(":")[0]])
+                elif cnode.name != stack_node.name:

Review comment:
       Add a comment to explain this.
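       For illustration, such a comment might read as follows (a hedged sketch of my reading of the quoted traversal, not the author's wording):

       ```python
       elif cnode.name != stack_node.name:
           # Traced back to a TensorList op other than the stack node we
           # started from; if it is the list constructor, record its static
           # element shape and stop the backward traversal.
       ```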

##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -325,12 +475,30 @@ def _convert_operator(self, graph, op_name, node_name, inputs, attrs):
                sym = _convert_map_common[op_name](inputs, attrs, self._params, self._prelude)
             else:
                sym = _convert_map_common[op_name](inputs, attrs, self._params, self._module.mod)
+        elif op_name in _convert_map_tf2:

Review comment:
       Add an assert to make sure the ops are exclusive.
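       A minimal sketch of such an assert, assuming both convert maps are plain dicts keyed by op name (the map names come from the quoted diff; the exact placement is illustrative):

       ```python
       # Illustrative only: fail fast if any op is registered in both the
       # common and the TF2-specific convert maps.
       overlap = set(_convert_map_common) & set(_convert_map_tf2)
       assert not overlap, f"Ops registered in both convert maps: {sorted(overlap)}"
       ```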

##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -215,10 +233,130 @@ def from_tensorflow(
         )
         return func, self._params
 
+    def _analysis_tensor_list_op(
+        self,
+        graph,
+        node,
+        tl_write_nodes,
+        tl_stack_nodes,
+        tl_construct_nodes,
+        sub_func_name="",
+        root_node="",
+    ):
+        if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+            self._sub_input_idx_map[sub_func_name] = {}
+        if node.op == "Placeholder":
+            # record placeholder node in sub functions
+            self._sub_map[sub_func_name] = node
+            self._sub_input_idx_map[sub_func_name][node.name] = len(
+                self._sub_input_idx_map[sub_func_name]
+            )
+        if node.op.startswith("TensorList"):

Review comment:
       Add a blank line.

##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -215,10 +233,130 @@ def from_tensorflow(
         )
         return func, self._params
 
+    def _analysis_tensor_list_op(
+        self,
+        graph,
+        node,
+        tl_write_nodes,
+        tl_stack_nodes,
+        tl_construct_nodes,
+        sub_func_name="",
+        root_node="",
+    ):
+        if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+            self._sub_input_idx_map[sub_func_name] = {}
+        if node.op == "Placeholder":
+            # record placeholder node in sub functions
+            self._sub_map[sub_func_name] = node
+            self._sub_input_idx_map[sub_func_name][node.name] = len(
+                self._sub_input_idx_map[sub_func_name]
+            )
+        if node.op.startswith("TensorList"):
+            if is_tensor_list_constructor(node):
+                tl_construct_nodes.append(node)
+            else:
+                for tl_write_name, idx in _tensor_list_write_ops.items():
+                    if node.op.startswith(tl_write_name):
+                        tl_write_nodes.append((node, idx, sub_func_name, root_node))
+                if node.op.startswith("TensorListStack"):
+                    tl_stack_nodes.append(node)
+        elif node.op.startswith("StatelessWhile"):
+            root_node = node.name
+            cond_fn_name, body_fn_name = [
+                parse_attr(node.attr).get(x).name for x in ["cond", "body"]
+            ]
+            for fn_name in [cond_fn_name, body_fn_name]:
+                subfunction = self._gdef_lib[fn_name]
+                sub_func_name = fn_name
+                for sub_node in subfunction.node:
+                    # bypass const node
+                    if sub_node.op == "Const":
+                        continue
+                    self._tf_node_map[sub_node.name] = sub_node
+                    self._analysis_tensor_list_op(
+                        subfunction,
+                        sub_node,
+                        tl_write_nodes,
+                        tl_stack_nodes,
+                        tl_construct_nodes,
+                        sub_func_name=sub_func_name,
+                        root_node=root_node,
+                    )
+
+    def _infer_static_shape_stack_node(self, tl_stack_nodes):
+        for stack_node in tl_stack_nodes:
+            if len(stack_node.input) < 2:
+                # Stack node does not have shape
+                continue
+            input_shape_name = stack_node.input[1].split(":")[0]
+            input_shape_node = self._tf_node_map[input_shape_name]
+            stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
+            in_idx = -1
+            while stack:
+                cnode = stack.pop(0)
+                if not cnode.op.startswith("TensorList"):
+                    if in_idx and cnode.op.startswith("StatelessWhile"):
+                        stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
+                    else:
+                        for iname in cnode.input:
+                            if self._tf_node_map[iname.split(":")[0]].op.startswith(
+                                "StatelessWhile"
+                            ):
+                                # identify input index based on output index
+                                if iname.split(":")[1]:
+                                    in_idx = int(iname.split(":")[1])
+                            stack.append(self._tf_node_map[iname.split(":")[0]])
+                elif cnode.name != stack_node.name:
+                    if is_tensor_list_constructor(cnode):
+                        shape_attr = parse_attr(input_shape_node.attr)
+                        if "value" not in shape_attr:
+                            continue
+                        raw_elem_shape = tensor_util.MakeNdarray(shape_attr["value"])
+                        elem_shape = []
+                        for dim in raw_elem_shape:
+                            if dim < 0:
+                                elem_shape.append(Any())
+                            else:
+                                elem_shape.append(int(dim))
+                        self._tensor_list_shapes[cnode.name] = elem_shape
+                    break
+
+    def _infer_static_shape_write_node(self, tl_write_nodes):
+        for item in tl_write_nodes:
+            wnode = item[0]
+            ta_idx, inode_idx = item[1]
+            sub_func_name = item[2]
+            root_name = item[3]
+            stack = [self._tf_node_map[wnode.input[ta_idx].split(":")[0]]]
+            while stack:
+                cnode = stack.pop(0)
+
+                if not cnode.op.startswith("TensorList"):
+                    if cnode.op == "Placeholder" and sub_func_name:
+                        # need to map subfunction
+                        input_idx = self._sub_input_idx_map[sub_func_name][cnode.name]
+                        stack.append(
+                            self._tf_node_map[
+                                self._tf_node_map[root_name].input[input_idx].split(":")[0]
+                            ]
+                        )
+                    else:
+                        for iname in cnode.input:
+                            stack.append(self._tf_node_map[iname.split(":")[0]])
+                elif cnode.name != wnode.name:

Review comment:
       ditto
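       As above, purely for illustration (a sketch of my reading, not the reviewer's wording):

       ```python
       elif cnode.name != wnode.name:
           # Traced back to a TensorList op other than the write node itself,
           # e.g. the constructor that produced the list being written to.
       ```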



