This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e6538517b0 [Relax][TFLite] Add gather frontend expected IRModule tests (#19516)
e6538517b0 is described below
commit e6538517b016e5788bec375a36c83fa8c246f400
Author: Wei-Cheng Hsu <[email protected]>
AuthorDate: Fri May 8 16:08:43 2026 +0800
[Relax][TFLite] Add gather frontend expected IRModule tests (#19516)
This adds explicit Expected IRModule coverage for TFLite GATHER and
GATHER_ND frontend conversion.
GATHER_ND lowers to Relax gather_nd, which requires int64 indices, so the
frontend now casts int32 TFLite indices to int64 before emitting the Relax op.
This keeps the generated module well-typed and matches the expected Relax IR.
Testing:
- `python -m pytest tests/python/relax/test_frontend_tflite.py -k "gather"`
Related to https://github.com/apache/tvm/issues/18971
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 3 ++
tests/python/relax/test_frontend_tflite.py | 58 ++++++++++++++++++++++
2 files changed, 61 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index e66dff8356..f5b88b0c6a 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -1630,6 +1630,9 @@ class OperatorConverter:
indices_dims = len(self._infer_shape(indices))
indices_t = relax.op.permute_dims(indices, axes=[-1] +
list(range(indices_dims - 1)))
+ if indices_type == TensorType.INT32:
+ # Relax gather_nd requires int64 indices.
+ indices_t = relax.op.astype(indices_t, "int64")
out = relax.op.gather_nd(data, indices_t)
return out
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 69e9b290fd..e4c237887e 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1451,6 +1451,64 @@ def test_reverse_v2():
verify(ReverseV2, Expected)
+
+def test_gather():
+ """Check TFLite GATHER (tf.gather on axis 1) against an explicit Expected IRModule."""
+ # Source TF model: gather along axis 1 of a (2, 3, 4) float32 tensor using
+ # a (2,) int64 index vector, producing a (2, 2, 4) result.
+ class Gather(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
+ tf.TensorSpec(shape=(2,), dtype=tf.int64),
+ ]
+ )
+ def func(self, x, indices):
+ return tf.gather(x, indices, axis=1)
+
+ # Expected Relax IR: the int64 indices are cast to int32 and passed to
+ # R.take along axis 1 with mode="fast".
+ # NOTE(review): the int32 cast is produced by the GATHER converter, which is
+ # not part of this patch -- confirm it is the intended lowering.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4), dtype="float32"),
+ indices: R.Tensor((2,), dtype="int64"),
+ ) -> R.Tensor((2, 2, 4), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2,), dtype="int32") = R.astype(indices,
dtype="int32")
+ gv: R.Tensor((2, 2, 4), dtype="float32") = R.take(x, lv,
axis=1, mode="fast")
+ R.output(gv)
+ return gv
+
+ verify(Gather, Expected)
+
+
+def test_gather_nd():
+ """Check TFLite GATHER_ND (tf.gather_nd) against an explicit Expected IRModule."""
+ # Source TF model: gather_nd on a (2, 3, 4) float32 tensor with (2, 2)
+ # int32 indices, producing a (2, 4) result. The int32 indices exercise the
+ # frontend's int32 -> int64 cast added in this patch.
+ class GatherND(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(2, 3, 4), dtype=tf.float32),
+ tf.TensorSpec(shape=(2, 2), dtype=tf.int32),
+ ]
+ )
+ def func(self, x, indices):
+ return tf.gather_nd(x, indices)
+
+ # Expected Relax IR: the indices are transposed so their last axis comes
+ # first (axes=[-1, 0]), then cast from int32 to int64 because Relax
+ # gather_nd requires int64 indices, and finally passed to R.gather_nd
+ # with batch_dims=0.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4), dtype="float32"),
+ indices: R.Tensor((2, 2), dtype="int32"),
+ ) -> R.Tensor((2, 4), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ lv: R.Tensor((2, 2), dtype="int32") = R.permute_dims(indices,
axes=[-1, 0])
+ lv1: R.Tensor((2, 2), dtype="int64") = R.astype(lv,
dtype="int64")
+ gv: R.Tensor((2, 4), dtype="float32") = R.gather_nd(x, lv1,
batch_dims=0)
+ R.output(gv)
+ return gv
+
+ verify(GatherND, Expected)
+
+
def _make_conv2d_module(data_shape, kernel_shape, data_format, strides,
padding):
class Conv2DModule(tf.Module):
@tf.function(