This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e3dda2398f [TFLite][Frontend] Add expected IRModule checks for conv2d, pool2d, and batch_matmul tests (#18970)
e3dda2398f is described below

commit e3dda2398f58dd4e1d3a5deb61761d96489077d5
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Apr 4 00:57:03 2026 -0400

    [TFLite][Frontend] Add expected IRModule checks for conv2d, pool2d, and batch_matmul tests (#18970)
    
    Add expected IRModule checks for conv2d, pool2d, and batch_matmul tests
---
 tests/python/relax/test_frontend_tflite.py | 227 ++++++++++++++++++++++++-----
 1 file changed, 194 insertions(+), 33 deletions(-)

diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index e7d81cf5fe..f1f91002c4 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -700,22 +700,12 @@ def test_reduce(tf_op, relax_op, axis, out_shape):
     verify(TfInput, Expected)
 
 
[email protected](
-    "data, kernel, data_format, strides, padding",
-    [
-        # Tf on CI (CPU) support only NHWC
-        ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"),
-        ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"),
-        # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"),
-        # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"),
-    ],
-)
-def test_conv2d(data, kernel, data_format, strides, padding):
+def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(
             input_signature=[
-                tf.TensorSpec(shape=data, dtype=tf.float32),
-                tf.TensorSpec(shape=kernel, dtype=tf.float32),
+                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
+                tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
             ]
         )
         def func(self, data, kernel):
@@ -727,39 +717,180 @@ def test_conv2d(data, kernel, data_format, strides, padding):
                 padding=padding,
             )
 
-    verify(Conv2DModule)
+    return Conv2DModule
 
 
[email protected](
-    "pool",
-    [tf.nn.avg_pool2d, tf.nn.max_pool2d],
-)
[email protected](
-    "data, kernel, data_format, strides, padding",
-    [
-        # Tf on CI (CPU) support only NHWC
-        ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"),
-        ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"),
-        # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"),
-        # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"),
-    ],
-)
-def test_pool_2d(pool, data, kernel, data_format, strides, padding):
+def test_conv2d_same():
+    Conv2DModule = _make_conv2d_module(
+        (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+            kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+                    kernel, axes=[3, 0, 1, 2]
+                )
+                lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims(
+                    lv, axes=[1, 2, 3, 0]
+                )
+                lv2: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.conv2d(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="HWIO",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add(
+                    lv2, R.const(np.zeros((32,), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv2DModule, Expected)
+
+
+def test_conv2d_valid():
+    Conv2DModule = _make_conv2d_module(
+        (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+            kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+        ) -> R.Tensor((1, 126, 126, 32), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+                    kernel, axes=[3, 0, 1, 2]
+                )
+                lv1: R.Tensor((3, 3, 32, 32), dtype="float32") = R.permute_dims(
+                    lv, axes=[1, 2, 3, 0]
+                )
+                lv2: R.Tensor((1, 126, 126, 32), dtype="float32") = R.nn.conv2d(
+                    data,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="HWIO",
+                    out_layout="NHWC",
+                    out_dtype="void",
+                )
+                gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add(
+                    lv2, R.const(np.zeros((32,), dtype="float32"))
+                )
+                R.output(gv)
+            return gv
+
+    verify(Conv2DModule, Expected)
+
+
+def _make_pool2d_module(pool, data_shape, ksize, data_format, strides, padding):
     class Pool2DModule(tf.Module):
         @tf.function(
             input_signature=[
-                tf.TensorSpec(shape=data, dtype=tf.float32),
+                tf.TensorSpec(shape=data_shape, dtype=tf.float32),
             ]
         )
         def func(self, data):
             return pool(
                 input=data,
-                ksize=kernel,
+                ksize=ksize,
                 data_format=data_format,
                 strides=strides,
                 padding=padding,
             )
 
+    return Pool2DModule
+
+
+def test_avg_pool2d_same():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.avg_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    ceil_mode=False,
+                    count_include_pad=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
+
+
+def test_avg_pool2d_valid():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"
+    )
+    verify(Pool2DModule)
+
+
+def test_max_pool2d_same():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+        ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.nn.max_pool2d(
+                    data,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    dilation=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    ceil_mode=False,
+                    layout="NHWC",
+                    out_layout="NHWC",
+                )
+                R.output(gv)
+            return gv
+
+    verify(Pool2DModule, Expected)
+
+
+def test_max_pool2d_valid():
+    Pool2DModule = _make_pool2d_module(
+        tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"
+    )
     verify(Pool2DModule)
 
 
@@ -836,7 +967,21 @@ def test_batch_matmul():
         def func(self, x, y):
             return tf.matmul(x, y)
 
-    verify(BatchMatMul)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4), dtype="float32"),
+            y: R.Tensor((2, 4, 5), dtype="float32"),
+        ) -> R.Tensor((2, 3, 5), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y, out_dtype="void")
+                gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv, R.shape([2, 3, 5]))
+                R.output(gv)
+            return gv
+
+    verify(BatchMatMul, Expected)
 
 
 def test_batch_matmul_adj():
@@ -850,7 +995,23 @@ def test_batch_matmul_adj():
         def func(self, x, y):
             return tf.matmul(x, y, transpose_a=True, transpose_b=True)
 
-    verify(BatchMatMulAdj)
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 4, 3), dtype="float32"),
+            y: R.Tensor((2, 5, 4), dtype="float32"),
+        ) -> R.Tensor((2, 3, 5), dtype="float32"):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x, axes=[0, 2, 1])
+                lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y, axes=[0, 2, 1])
+                lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1, out_dtype="void")
+                gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2, R.shape([2, 3, 5]))
+                R.output(gv)
+            return gv
+
+    verify(BatchMatMulAdj, Expected)
 
 
 if __name__ == "__main__":

Reply via email to