yongwww commented on a change in pull request #8454:
URL: https://github.com/apache/tvm/pull/8454#discussion_r675347603
##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -38,12 +38,20 @@
from .common import infer_type as _infer_type
from .tensorflow_ops import _convert_map as _convert_map_common
-from .tensorflow_ops import _need_prelude_for_shape_inference
+from .tensorflow_ops import _get_more_static_shape_rank
+from .tensorflow2_ops import _convert_map as _convert_map_tf2
+from .tensorflow2_ops import _need_prelude_for_shape_inference
from ..ty import Any
__all__ = ["from_tensorflow"]
+# A map to record tensor list write ops and input tl/tensor indices
+# Value is (index of tensor list, index of written node)
+_tensor_list_write_ops = {
+ "TensorListSetItem": (0, 2),
Review comment:
Is there any particular reason why you chose the indices 0 and 2 here?
##########
File path: python/tvm/relay/frontend/tensorflow2_ops.py
##########
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, too-many-lines,
len-as-condition, broad-except
+"""Tensorflow2.x to relay converter ops and helper"""
+import tvm
+from tvm.relay.prelude import StaticTensorArrayOps, get_tensor_array_shape
+
+from .. import op as _op
+from ..ty import Any
+from .common import infer_value as _infer_value
+from .common import infer_type as _infer_type
+from .tensorflow_ops import _get_more_static_shape_rank
+
+
+def _infer_type_with_prelude(val, prelude):
+ body = _infer_type(val, prelude.mod)
+ return body.checked_type
+
+
+def _need_prelude_for_shape_inference(op):
+ return "TensorList" in op or "TensorArray" in op
+
+
+def _tensorlist_reserve():
+ def _impl(inputs, attr, params, prelude):
+ dtype_str = attr.get("element_dtype").name
+ elem_shape = _infer_value(inputs[0], params, prelude.mod)
+ elem_shape = tuple(elem_shape.asnumpy().astype("int32").flatten())
+
+ if elem_shape or "shape" in attr:
+ shape = attr["shape"] if "shape" in attr else elem_shape
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
shape)
+ static_tensor_array_ops.register()
+ tensor_array_constructor =
static_tensor_array_ops.get_global_var("tensor_array")
+ tensor_array = tensor_array_constructor(inputs[1])
+ else:
+ tensor_array_constructor = prelude.get_global_var("tensor_array",
dtype_str)
+ tensor_array = tensor_array_constructor(inputs[1])
+ return tensor_array
+
+ return _impl
+
+
+def _tensorlist_set_item():
+ def _impl(inputs, attr, params, prelude):
+ dtype_str = attr.get("element_dtype").name
+ input_ta = inputs[0]
+ input_ta_shape = get_tensor_array_shape(input_ta, dtype_str, prelude)
+ input_t_shape = _infer_type_with_prelude(inputs[2], prelude).shape
+ input_rank = len(input_t_shape)
+
+ if input_ta_shape is None:
+ tensor_name = "tensor{}".format(input_rank)
+ tensor_func = prelude.get_tensor_ctor(tensor_name, dtype_str)
+ v = tensor_func(inputs[2])
+ write_func = prelude.get_global_var("tensor_array_write",
dtype_str)
+ out = write_func(input_ta, inputs[1], v)
+ else:
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
input_ta_shape)
+ static_tensor_array_ops.register()
+ tensor_func =
static_tensor_array_ops.get_ctor("tensor_constructor")
+ v = tensor_func(inputs[2])
+ # Write tensor with more static shape
+ # convert shape with -1 to any()
+ input_ta_shape_a = []
+ for dim in input_ta_shape:
+ if isinstance(dim, (int, tvm.tir.expr.IntImm)):
+ if dim < 0:
+ input_ta_shape_a.append(Any())
+ else:
+ input_ta_shape_a.append(dim)
+ else:
+ input_ta_shape_a.append(dim)
+ actual_shape = _get_more_static_shape_rank(input_t_shape,
input_ta_shape_a)
+ if actual_shape != input_ta_shape_a:
+ new_shape = []
+ num_any_dim = 0
+ for dim in actual_shape:
+ if not isinstance(dim, int):
+ num_any_dim += 1
+ new_shape.append(dim if isinstance(dim, int) else -1)
+ if num_any_dim <= 1:
+ v = tensor_func(_op.reshape(inputs[2], new_shape))
+ write_func = prelude.get_global_var_static(
+ "tensor_array_write", dtype_str, input_ta_shape_a
+ )
+ out = write_func(input_ta, inputs[1], v)
+ return out
+
+ return _impl
+
+
+def _tensorlist_get_item():
+ def _impl(inputs, attr, params, prelude):
+ dtype_str = attr["element_dtype"].name
+ input_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude)
+
+ if input_shape is None:
+ read_func = prelude.get_global_var("tensor_array_read", dtype_str)
+ out = read_func(inputs[0], _op.take(inputs[1], tvm.relay.const(0)))
+ else:
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
input_shape)
+ static_tensor_array_ops.register()
+ read_func =
static_tensor_array_ops.get_global_var("tensor_array_read")
+ out_tensor = read_func(inputs[0], _op.take(inputs[1],
tvm.relay.const(0)))
+ get_data_func =
static_tensor_array_ops.get_global_var("tensor_get_data")
+ out = get_data_func(out_tensor)
+ return out
+
+ return _impl
+
+
+def _tensorlist_stack():
+ def _impl(inputs, attr, params, prelude):
+ dtype_str = attr["element_dtype"].name
+ input_ta_shape = get_tensor_array_shape(inputs[0], dtype_str, prelude)
+
+ if input_ta_shape is None:
+ stack_func = prelude.get_global_var("tensor_array_stack",
dtype_str)
+ out = stack_func(inputs[0])
+ else:
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
input_ta_shape)
+ static_tensor_array_ops.register()
+ stack_func = prelude.get_global_var_static(
+ "tensor_array_stack", dtype_str, input_ta_shape
+ )
+ out_tensor = stack_func(inputs[0])
+ out_shape = (Any(),) + input_ta_shape
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
out_shape)
+ static_tensor_array_ops.register()
+ get_data_func = prelude.get_global_var_static("tensor_get_data",
dtype_str, out_shape)
+ out = get_data_func(out_tensor)
+
+ return out
+
+ return _impl
+
+
+def _tensorlist_from_tensor():
+ def _impl(inputs, attr, params, prelude):
+ dtype_str = attr["element_dtype"].name
+ input_ta_shape = _infer_type_with_prelude(inputs[0], prelude).shape
+
+ if input_ta_shape is None:
+ unstack_func = prelude.get_global_var("tensor_array_unstack",
dtype_str)
+ out = unstack_func(inputs[0])
+ else:
+ static_tensor_array_ops = StaticTensorArrayOps(prelude, dtype_str,
input_ta_shape)
+ static_tensor_array_ops.register()
+ unstack_func = prelude.get_global_var_static(
+ "tensor_array_unstack", dtype_str, input_ta_shape
+ )
+ out = unstack_func(inputs[0])
+ return out
+
+ return _impl
+
+
+_convert_map = {
+ "TensorListFromTensor": _tensorlist_from_tensor(),
+ "TensorListGetItem": _tensorlist_get_item(),
+ "TensorListReserve": _tensorlist_reserve(),
+ "TensorListSetItem": _tensorlist_set_item(),
+ "TensorListStack": _tensorlist_stack(),
Review comment:
It seems that several other TensorList-related ops, such as
TensorListScatter, are not added here. What was the reasoning behind the
set of ops you chose to support?
##########
File path: python/tvm/relay/frontend/tensorflow2.py
##########
@@ -215,10 +232,134 @@ def from_tensorflow(
)
return func, self._params
+ def _analysis_tensor_list_op(
+ self,
+ graph,
+ node,
+ tl_write_nodes,
+ tl_stack_nodes,
+ tl_construct_nodes,
+ sub_func_name="",
+ root_node="",
+ ):
+ if sub_func_name and sub_func_name not in self._sub_input_idx_map:
+ self._sub_input_idx_map[sub_func_name] = {}
+
+ if node.op == "Placeholder":
+ # record placeholder node in sub functions
+ self._sub_map[sub_func_name] = node
+ self._sub_input_idx_map[sub_func_name][node.name] = len(
+ self._sub_input_idx_map[sub_func_name]
+ )
+
+ if node.op.startswith("TensorList"):
+ if is_tensor_list_constuctor(node):
+ tl_construct_nodes.append(node)
+ else:
+ for tl_write_name, idx in _tensor_list_write_ops.items():
+ if node.op.startswith(tl_write_name):
+ tl_write_nodes.append((node, idx, sub_func_name,
root_node))
+ if node.op.startswith("TensorListStack"):
+ tl_stack_nodes.append(node)
+ elif node.op.startswith("StatelessWhile"):
+ root_node = node.name
+ cond_fn_name, body_fn_name = [
+ parse_attr(node.attr).get(x).name for x in ["cond", "body"]
+ ]
+ for fn_name in [cond_fn_name, body_fn_name]:
+ subfunction = self._gdef_lib[fn_name]
+ sub_func_name = fn_name
+ for sub_node in subfunction.node:
+ # bypass const node
+ if sub_node.op == "Const":
+ continue
+ self._tf_node_map[sub_node.name] = sub_node
+ self._analysis_tensor_list_op(
+ subfunction,
+ sub_node,
+ tl_write_nodes,
+ tl_stack_nodes,
+ tl_construct_nodes,
+ sub_func_name=sub_func_name,
+ root_node=root_node,
+ )
+
+ def _infer_static_shape_stack_node(self, tl_stack_nodes):
+ for stack_node in tl_stack_nodes:
+ if len(stack_node.input) < 2:
+ # Stack node does not have shape
+ continue
+ input_shape_name = stack_node.input[1].split(":")[0]
+ input_shape_node = self._tf_node_map[input_shape_name]
+ stack = [self._tf_node_map[stack_node.input[0].split(":")[0]]]
+ in_idx = -1
+ while stack:
+ cnode = stack.pop(0)
+ if not cnode.op.startswith("TensorList"):
+ if in_idx and cnode.op.startswith("StatelessWhile"):
+
stack.append(self._tf_node_map[cnode.input[in_idx].split(":")[0]])
+ else:
+ for iname in cnode.input:
+ if
self._tf_node_map[iname.split(":")[0]].op.startswith(
+ "StatelessWhile"
+ ):
+ # identify input index based on output index
+ if iname.split(":")[1]:
+ in_idx = int(iname.split(":")[1])
+
stack.append(self._tf_node_map[iname.split(":")[0]])
+ # identify the corresponding constructor node and add shape to
_tensor_list_shapes
+ elif cnode.name != stack_node.name:
+ if is_tensor_list_constuctor(cnode):
+ shape_attr = parse_attr(input_shape_node.attr)
+ if "value" not in shape_attr:
Review comment:
Just out of curiosity: for the case where 'value' is not in shape_attr, does
that mean it is an input node without a shape?
##########
File path: tests/python/frontend/tensorflow2/test_sequential_models.py
##########
@@ -109,5 +109,60 @@ def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)):
run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3))
+def test_tensorlist_stack_model():
+ def tensorlist_stack_model(input_shape):
+ class TensorArrayStackLayer(tf.keras.layers.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def call(self, inputs):
+ inputs = tf.squeeze(inputs)
+ outputs = tf.TensorArray(
+ tf.float32,
+ size=inputs.shape[0],
+ infer_shape=False,
+ element_shape=inputs.shape[1:],
+ )
+ outputs = outputs.unstack(inputs)
+
+ return outputs.stack()
+
+ input_shape = (3, 32)
+ model = tf.keras.Sequential(
+ [tf.keras.layers.Input(shape=input_shape, batch_size=1),
TensorArrayStackLayer()]
+ )
+ return model
+
+ run_sequential_model(tensorlist_stack_model, input_shape=(3, 32))
+
+
+def test_tensorlist_read_model():
+ def tensorlist_read_model(input_shape):
+ class TensorArrayReadLayer(tf.keras.layers.Layer):
+ def __init__(self):
+ super().__init__()
+
+ def call(self, inputs):
+ inputs = tf.squeeze(inputs)
+ outputs = tf.TensorArray(
+ tf.float32,
+ size=inputs.shape[0],
+ infer_shape=False,
+ element_shape=inputs.shape[1:],
+ )
+ for i in range(inputs.shape[0]):
+ outputs = outputs.write(i, inputs[i, :])
+
+ return outputs.read(0)
+
+ input_shape = (3, 32)
+ model = tf.keras.Sequential(
+ [tf.keras.layers.Input(shape=input_shape, batch_size=1),
TensorArrayReadLayer()]
+ )
+ return model
+
+ run_sequential_model(tensorlist_read_model, input_shape=(3, 32))
Review comment:
Maybe add a test case with a public model that contains TensorList
ops.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]