zxy844288792 commented on a change in pull request #8454:
URL: https://github.com/apache/tvm/pull/8454#discussion_r675703663



##########
File path: tests/python/frontend/tensorflow2/test_sequential_models.py
##########
@@ -109,5 +109,60 @@ def maxpool_batchnorm_model(input_shape, pool_size=(2, 2)):
     run_sequential_model(maxpool_batchnorm_model, input_shape=(1, 32, 32, 3))
 
 
+def test_tensorlist_stack_model():
+    def tensorlist_stack_model(input_shape):
+        class TensorArrayStackLayer(tf.keras.layers.Layer):
+            def __init__(self):
+                super().__init__()
+
+            def call(self, inputs):
+                inputs = tf.squeeze(inputs)
+                outputs = tf.TensorArray(
+                    tf.float32,
+                    size=inputs.shape[0],
+                    infer_shape=False,
+                    element_shape=inputs.shape[1:],
+                )
+                outputs = outputs.unstack(inputs)
+
+                return outputs.stack()
+
+        input_shape = (3, 32)
+        model = tf.keras.Sequential(
+            [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayStackLayer()]
+        )
+        return model
+
+    run_sequential_model(tensorlist_stack_model, input_shape=(3, 32))
+
+
+def test_tensorlist_read_model():
+    def tensorlist_read_model(input_shape):
+        class TensorArrayReadLayer(tf.keras.layers.Layer):
+            def __init__(self):
+                super().__init__()
+
+            def call(self, inputs):
+                inputs = tf.squeeze(inputs)
+                outputs = tf.TensorArray(
+                    tf.float32,
+                    size=inputs.shape[0],
+                    infer_shape=False,
+                    element_shape=inputs.shape[1:],
+                )
+                for i in range(inputs.shape[0]):
+                    outputs = outputs.write(i, inputs[i, :])
+
+                return outputs.read(0)
+
+        input_shape = (3, 32)
+        model = tf.keras.Sequential(
+            [tf.keras.layers.Input(shape=input_shape, batch_size=1), TensorArrayReadLayer()]
+        )
+        return model
+
+    run_sequential_model(tensorlist_read_model, input_shape=(3, 32))

Review comment:
       I have some public models, but they trigger a corner case that I will fix in the next PR. The TensorList ops on their own work fine.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to