This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e3dda2398f [TFLite][Frontend] Add expected IRModule checks for conv2d,
pool2d, and batch_matmul tests (#18970)
e3dda2398f is described below
commit e3dda2398f58dd4e1d3a5deb61761d96489077d5
Author: Shushi Hong <[email protected]>
AuthorDate: Sat Apr 4 00:57:03 2026 -0400
[TFLite][Frontend] Add expected IRModule checks for conv2d, pool2d, and
batch_matmul tests (#18970)
Add expected IRModule checks for conv2d, pool2d, and batch_matmul tests
---
tests/python/relax/test_frontend_tflite.py | 227 ++++++++++++++++++++++++-----
1 file changed, 194 insertions(+), 33 deletions(-)
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index e7d81cf5fe..f1f91002c4 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -700,22 +700,12 @@ def test_reduce(tf_op, relax_op, axis, out_shape):
verify(TfInput, Expected)
-@pytest.mark.parametrize(
- "data, kernel, data_format, strides, padding",
- [
- # Tf on CI (CPU) support only NHWC
- ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"),
- ((1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"),
- # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "SAME"),
- # ((1, 32, 128, 128), (3, 3, 32, 32), "NCHW", (1, 1, 1, 1), "VALID"),
- ],
-)
-def test_conv2d(data, kernel, data_format, strides, padding):
+def _make_conv2d_module(data_shape, kernel_shape, data_format, strides,
padding):
class Conv2DModule(tf.Module):
@tf.function(
input_signature=[
- tf.TensorSpec(shape=data, dtype=tf.float32),
- tf.TensorSpec(shape=kernel, dtype=tf.float32),
+ tf.TensorSpec(shape=data_shape, dtype=tf.float32),
+ tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
]
)
def func(self, data, kernel):
@@ -727,39 +717,180 @@ def test_conv2d(data, kernel, data_format, strides,
padding):
padding=padding,
)
- verify(Conv2DModule)
+ return Conv2DModule
-@pytest.mark.parametrize(
- "pool",
- [tf.nn.avg_pool2d, tf.nn.max_pool2d],
-)
-@pytest.mark.parametrize(
- "data, kernel, data_format, strides, padding",
- [
- # Tf on CI (CPU) support only NHWC
- ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "SAME"),
- ((1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1), "VALID"),
- # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "SAME"),
- # ((1, 32, 128, 128), (3, 3), "NCHW", (1, 1, 1, 1), "VALID"),
- ],
-)
-def test_pool_2d(pool, data, kernel, data_format, strides, padding):
+def test_conv2d_same():
+ Conv2DModule = _make_conv2d_module(
+ (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "SAME"
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+ kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+ ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+ kernel, axes=[3, 0, 1, 2]
+ )
+ lv1: R.Tensor((3, 3, 32, 32), dtype="float32") =
R.permute_dims(
+ lv, axes=[1, 2, 3, 0]
+ )
+ lv2: R.Tensor((1, 128, 128, 32), dtype="float32") =
R.nn.conv2d(
+ data,
+ lv1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ out_layout="NHWC",
+ out_dtype="void",
+ )
+ gv: R.Tensor((1, 128, 128, 32), dtype="float32") = R.add(
+ lv2, R.const(np.zeros((32,), dtype="float32"))
+ )
+ R.output(gv)
+ return gv
+
+ verify(Conv2DModule, Expected)
+
+
+def test_conv2d_valid():
+ Conv2DModule = _make_conv2d_module(
+ (1, 128, 128, 32), (3, 3, 32, 32), "NHWC", (1, 1, 1, 1), "VALID"
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+ kernel: R.Tensor((3, 3, 32, 32), dtype="float32"),
+ ) -> R.Tensor((1, 126, 126, 32), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((32, 3, 3, 32), dtype="float32") = R.permute_dims(
+ kernel, axes=[3, 0, 1, 2]
+ )
+ lv1: R.Tensor((3, 3, 32, 32), dtype="float32") =
R.permute_dims(
+ lv, axes=[1, 2, 3, 0]
+ )
+ lv2: R.Tensor((1, 126, 126, 32), dtype="float32") =
R.nn.conv2d(
+ data,
+ lv1,
+ strides=[1, 1],
+ padding=[0, 0, 0, 0],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NHWC",
+ kernel_layout="HWIO",
+ out_layout="NHWC",
+ out_dtype="void",
+ )
+ gv: R.Tensor((1, 126, 126, 32), dtype="float32") = R.add(
+ lv2, R.const(np.zeros((32,), dtype="float32"))
+ )
+ R.output(gv)
+ return gv
+
+ verify(Conv2DModule, Expected)
+
+
+def _make_pool2d_module(pool, data_shape, ksize, data_format, strides,
padding):
class Pool2DModule(tf.Module):
@tf.function(
input_signature=[
- tf.TensorSpec(shape=data, dtype=tf.float32),
+ tf.TensorSpec(shape=data_shape, dtype=tf.float32),
]
)
def func(self, data):
return pool(
input=data,
- ksize=kernel,
+ ksize=ksize,
data_format=data_format,
strides=strides,
padding=padding,
)
+ return Pool2DModule
+
+
+def test_avg_pool2d_same():
+ Pool2DModule = _make_pool2d_module(
+ tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1),
"SAME"
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+ ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 128, 128, 32), dtype="float32") =
R.nn.avg_pool2d(
+ data,
+ pool_size=[2, 2],
+ strides=[1, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 1, 1],
+ ceil_mode=False,
+ count_include_pad=False,
+ layout="NHWC",
+ out_layout="NHWC",
+ )
+ R.output(gv)
+ return gv
+
+ verify(Pool2DModule, Expected)
+
+
+def test_avg_pool2d_valid():
+ Pool2DModule = _make_pool2d_module(
+ tf.nn.avg_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1),
"VALID"
+ )
+ verify(Pool2DModule)
+
+
+def test_max_pool2d_same():
+ Pool2DModule = _make_pool2d_module(
+ tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1),
"SAME"
+ )
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 128, 128, 32), dtype="float32"),
+ ) -> R.Tensor((1, 128, 128, 32), dtype="float32"):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ gv: R.Tensor((1, 128, 128, 32), dtype="float32") =
R.nn.max_pool2d(
+ data,
+ pool_size=[2, 2],
+ strides=[1, 1],
+ dilation=[1, 1],
+ padding=[0, 0, 1, 1],
+ ceil_mode=False,
+ layout="NHWC",
+ out_layout="NHWC",
+ )
+ R.output(gv)
+ return gv
+
+ verify(Pool2DModule, Expected)
+
+
+def test_max_pool2d_valid():
+ Pool2DModule = _make_pool2d_module(
+ tf.nn.max_pool2d, (1, 128, 128, 32), (2, 2), "NHWC", (1, 1, 1, 1),
"VALID"
+ )
verify(Pool2DModule)
@@ -836,7 +967,21 @@ def test_batch_matmul():
def func(self, x, y):
return tf.matmul(x, y)
- verify(BatchMatMul)
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4), dtype="float32"),
+ y: R.Tensor((2, 4, 5), dtype="float32"),
+ ) -> R.Tensor((2, 3, 5), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(x, y,
out_dtype="void")
+ gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv,
R.shape([2, 3, 5]))
+ R.output(gv)
+ return gv
+
+ verify(BatchMatMul, Expected)
def test_batch_matmul_adj():
@@ -850,7 +995,23 @@ def test_batch_matmul_adj():
def func(self, x, y):
return tf.matmul(x, y, transpose_a=True, transpose_b=True)
- verify(BatchMatMulAdj)
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 4, 3), dtype="float32"),
+ y: R.Tensor((2, 5, 4), dtype="float32"),
+ ) -> R.Tensor((2, 3, 5), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 4), dtype="float32") = R.permute_dims(x,
axes=[0, 2, 1])
+ lv1: R.Tensor((2, 4, 5), dtype="float32") = R.permute_dims(y,
axes=[0, 2, 1])
+ lv2: R.Tensor((2, 3, 5), dtype="float32") = R.matmul(lv, lv1,
out_dtype="void")
+ gv: R.Tensor((2, 3, 5), dtype="float32") = R.reshape(lv2,
R.shape([2, 3, 5]))
+ R.output(gv)
+ return gv
+
+ verify(BatchMatMulAdj, Expected)
if __name__ == "__main__":