This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d30e347a69 [Relax] Use weight shape instead of dim in
Embedding.forward (#18621)
d30e347a69 is described below
commit d30e347a69fb85a5fafa62be4bcd579f8d4ca728
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Dec 31 20:10:07 2025 +0800
[Relax] Use weight shape instead of dim in Embedding.forward (#18621)
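## Why
The Embedding module's forward method reshaped its output using the
separately stored `self.dim`, even though the same value is already
available as `self.weight.shape[1]`. Keeping two sources of truth for the
embedding dimension was redundant, and an existing TODO asked for it to
be removed.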
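## How
- Changed `shape=[*x.shape, self.dim]` to
  `shape=[*x.shape, self.weight.shape[1]]` in `Embedding.forward`,
  removing the associated TODO comment.
- Split `test_embedding` into `test_embedding_1d` (new, 1-D input) and
  `test_embedding_2d` (the existing test, renamed).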
---
python/tvm/relax/frontend/nn/modules.py | 2 +-
tests/python/relax/test_frontend_nn_modules.py | 21 ++++++++++++++++++++-
2 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 5ca5f72787..455e42df97 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -743,7 +743,7 @@ class Embedding(Module):
op.reshape(x, shape=[-1]),
axis=0,
),
- shape=[*x.shape, self.dim], # TODO(@junrushao): revisit and
remove self.dim
+ shape=[*x.shape, self.weight.shape[1]],
)
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index e9a4a6f624..8dc4994465 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -365,7 +365,26 @@ def test_group_norm():
assert_structural_equal(tvm_mod["forward"], forward, True)
-def test_embedding():
+def test_embedding_1d():
+ @R.function
+ def forward(
+ x: R.Tensor((4,), dtype="int32"),
+ _io: R.Object,
+ weight: R.Tensor((8, 16), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tuple(R.Object)):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ take: R.Tensor((4, 16), dtype="float32") = R.take(weight, x,
axis=0)
+ gv1: R.Tuple(R.Tensor((4, 16), dtype="float32"),
R.Tuple(R.Object)) = take, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Embedding(8, 16, "float32")
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((4,),
"int32")}}, debug=True)
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
+def test_embedding_2d():
@R.function
def forward(
x: R.Tensor((1, 4), dtype="int32"),