This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 162d43a997 [Relax][PyTorch] Add support for torch.einsum (#17186)
162d43a997 is described below

commit 162d43a9978f3d31cfd48e3e0ad70ffbba5d22ec
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Jul 23 13:23:12 2024 +0900

    [Relax][PyTorch] Add support for torch.einsum (#17186)
    
    Add torch.einsum support to Relax PyTorch Frontend.
---
 python/tvm/relax/frontend/torch/fx_translator.py |  9 +++++
 tests/python/relax/test_frontend_from_fx.py      | 43 ++++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f9a5d9c33f..e6b39c3eee 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -518,6 +518,14 @@ class TorchFXImporter:
             res = bias if res is None else 
self.block_builder.emit(relax.op.add(res, bias))
         return res
 
+    def _einsum(self, node: fx.node.Node) -> relax.Var:  # Lower torch.einsum to relax.op.einsum.
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)  # args[0]: subscripts string (e.g. "i,j->ij"); args[1:]: operands.
+        if isinstance(args[1], (torch.Size, tuple, list)):  # Operands passed as one sequence, e.g. einsum(eq, [x, y]).
+            return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
+        return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))  # Variadic operands. NOTE(review): torch's sublist form (operand, [dims], ...) is not handled here.
+
     ########## Manipulation ##########
 
     def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -1482,6 +1490,7 @@ class TorchFXImporter:
             "max": self._max,
             "cross_entropy": self._cross_entropy,
             "scaled_dot_product_attention": self._scaled_dot_product_attention,
+            "einsum": self._einsum,
         }
 
     def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 46c079aa99..b4ac3fa60c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -650,6 +650,49 @@ def test_baddbmm():
     )
 
 
+def test_einsum():  # torch.einsum should import as R.einsum for one- and two-operand equations.
+    class Einsum1(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.einsum("ii", x)  # Trace of a square matrix; the result is 0-d.
+
+    class Einsum2(Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x, y):
+            return torch.einsum("i,j->ij", x, y)  # Outer product of two vectors.
+
+    @tvm.script.ir_module
+    class Expected1:  # Expected IR for Einsum1 with a (4, 4) float32 input.
+        @R.function
+        def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:  # Expected IR for Einsum2 with (5,) and (4,) float32 inputs.
+        @R.function
+        def main(
+            inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
+        ) -> R.Tensor((5, 4), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+                    (inp_0, inp_1), subscripts="i,j->ij"
+                )
+                gv: R.Tensor((5, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1)
+    verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
+
+
 def test_relu():
     class ReLU0(Module):
         def __init__(self):

Reply via email to