This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d30e347a69 [Relax] Use weight shape instead of dim in Embedding.forward (#18621)
d30e347a69 is described below

commit d30e347a69fb85a5fafa62be4bcd579f8d4ca728
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Dec 31 20:10:07 2025 +0800

    [Relax] Use weight shape instead of dim in Embedding.forward (#18621)
    
    ## Why
    
    The Embedding module's forward method was using `self.dim` redundantly
    when reshaping the output, even though this dimension is already
    available from `self.weight.shape[1]`.
    
    ## How
    
    - Changed `shape=[*x.shape, self.dim]` to `shape=[*x.shape, self.weight.shape[1]]` in `Embedding.forward`
---
 python/tvm/relax/frontend/nn/modules.py        |  2 +-
 tests/python/relax/test_frontend_nn_modules.py | 21 ++++++++++++++++++++-
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 5ca5f72787..455e42df97 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -743,7 +743,7 @@ class Embedding(Module):
                 op.reshape(x, shape=[-1]),
                 axis=0,
             ),
-            shape=[*x.shape, self.dim],  # TODO(@junrushao): revisit and remove self.dim
+            shape=[*x.shape, self.weight.shape[1]],
         )
 
 
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index e9a4a6f624..8dc4994465 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -365,7 +365,26 @@ def test_group_norm():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
-def test_embedding():
+def test_embedding_1d():
+    @R.function
+    def forward(
+        x: R.Tensor((4,), dtype="int32"),
+        _io: R.Object,
+        weight: R.Tensor((8, 16), dtype="float32"),
+    ) -> R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tuple(R.Object)):
+        R.func_attr({"num_input": 2})
+        with R.dataflow():
+            take: R.Tensor((4, 16), dtype="float32") = R.take(weight, x, axis=0)
+            gv1: R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tuple(R.Object)) = take, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Embedding(8, 16, "float32")
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((4,), "int32")}}, debug=True)
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
+def test_embedding_2d():
     @R.function
     def forward(
         x: R.Tensor((1, 4), dtype="int32"),

Reply via email to