comaniac commented on a change in pull request #8558:
URL: https://github.com/apache/tvm/pull/8558#discussion_r676917031
##########
File path: tests/python/frontend/tensorflow2/test_functional_models.py
##########
@@ -584,5 +584,44 @@ def func(self, x):
)
+def test_tensorlist_stack_unpack():
+ def run_test(elem_shape):
+ class TensorListStack2D(tf.Module):
+ def get_input(self):
+ in_tens = np.ones((1, 3, 4), dtype="float32")
+ return in_tens
+
+ """2D array as input"""
Review comment:
This docstring is oddly placed (it sits between methods, so it documents nothing). Better to move it to a proper place, e.g. directly under the class definition.
##########
File path: python/tvm/relay/prelude.py
##########
@@ -89,26 +89,32 @@ def _get_name_static(canonical, dtype, shape):
shape_str = "scalar"
if canonical == "tensor_t":
return "static_tensor_{}_{}_t".format(dtype, shape_str)
- return "{}_{}_{}".format(canonical, dtype, shape_str)
+ if not batch_dim or canonical == "tensor_constructor" or canonical ==
"tensor_nil":
Review comment:
Be careful that `bool(batch_dim)` is also `False` when `batch_dim == 0`. Better to use
`batch_dim is None` if that's the desired semantics.
```suggestion
        if batch_dim is None or canonical in ["tensor_constructor",
"tensor_nil"]:
```
##########
File path: python/tvm/relay/prelude.py
##########
@@ -89,26 +89,32 @@ def _get_name_static(canonical, dtype, shape):
shape_str = "scalar"
if canonical == "tensor_t":
return "static_tensor_{}_{}_t".format(dtype, shape_str)
- return "{}_{}_{}".format(canonical, dtype, shape_str)
+ if not batch_dim or canonical == "tensor_constructor" or canonical ==
"tensor_nil":
+ return "{}_{}_{}".format(canonical, dtype, shape_str)
+ if batch_dim != 1:
+ return "{}_{}_{}".format(canonical, dtype, shape_str)
+ else:
+ return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim),
shape_str)
Review comment:
nit
```suggestion
return "{}_{}_batch{}_{}".format(canonical, dtype, str(batch_dim),
shape_str)
```
##########
File path: python/tvm/relay/prelude.py
##########
@@ -599,8 +613,9 @@ def define_tensor_array_gather(self):
helper_name = self.get_name("tensor_array_gather_helper")
helper_var = self._create_global_var(helper_name)
+ new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else
self.batch_dim
Review comment:
ditto
##########
File path: python/tvm/relay/prelude.py
##########
@@ -262,9 +268,10 @@ def define_tensor_expand_dims(self):
# Note: we set the added axis to be Any() instead of 1 due to
# in stack op, we need to recursively concatenate.
+ new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else
self.batch_dim
Review comment:
ditto
##########
File path: python/tvm/relay/prelude.py
##########
@@ -73,7 +73,7 @@ def get_tensor_array_shape(expr, dtype, prelude):
return None
-def _get_name_static(canonical, dtype, shape):
+def _get_name_static(canonical, dtype, shape, batch_dim=None):
"""Get name for static shape tensor array op corresponding
to the canonical name"""
Review comment:
Better to improve the docstring to explain parameters, especially we now
have `batch_dim`.
##########
File path: tests/python/frontend/tensorflow2/test_functional_models.py
##########
@@ -584,5 +584,44 @@ def func(self, x):
)
+def test_tensorlist_stack_unpack():
+ def run_test(elem_shape):
+ class TensorListStack2D(tf.Module):
+ def get_input(self):
+ in_tens = np.ones((1, 3, 4), dtype="float32")
+ return in_tens
+
+ """2D array as input"""
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=(1, 3, 4),
dtype=tf.float32)])
+ def func(self, x):
+ dtype = tf.float32
+ tl = tf.raw_ops.TensorListReserve(
+ element_shape=elem_shape, num_elements=1,
element_dtype=dtype
+ )
+ tl = tf.raw_ops.TensorListSetItem(input_handle=tl, index=0,
item=x[0, :, :])
+ output = tf.raw_ops.TensorListStack(
+ input_handle=tl, element_shape=elem_shape,
element_dtype=dtype, num_elements=1
+ )
+ output = tf.raw_ops.Unpack(value=output, num=1, axis=0)
+ return output
+
+ run_model_graph(TensorListStack2D)
+ run_func_graph(TensorListStack2D, runtime="vm")
+
+ run_test(
+ (
+ 3,
+ 4,
+ )
+ )
+ run_test(
+ (
+ -1,
+ -1,
+ )
+ )
Review comment:
Can each of these `run_test` tuple arguments be written on one line (e.g. `run_test((3, 4))`)?
##########
File path: python/tvm/relay/prelude.py
##########
@@ -573,20 +580,27 @@ def define_tensor_array_stack(self):
expand_dims_var = self.get_global_var("tensor_expand_dims")
# Register tensor_concatenate for output_shape
+ new_axis = Any() if not self.batch_dim or self.batch_dim != 1 else
self.batch_dim
output_shape = [
- Any(),
+ new_axis,
] + list(self.shape)
-
_, _, output_ops = self._get_adt_by_shape(output_shape)
output_ops.define_tensor_concatenate()
concat_var = output_ops.get_global_var("tensor_concatenate")
tensor_array_expand_dims = self.prelude.map(expand_dims_var,
tensor_array)
- tensors = self.prelude.foldl(
- concat_var,
- self.prelude.hd(tensor_array_expand_dims),
- self.prelude.tl(tensor_array_expand_dims),
- )
+ if self.batch_dim and self.batch_dim == 1:
Review comment:
```suggestion
if self.batch_dim is not None and self.batch_dim == 1:
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]