This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 162d43a997 [Relax][PyTorch] Add support for torch.einsum (#17186)
162d43a997 is described below
commit 162d43a9978f3d31cfd48e3e0ad70ffbba5d22ec
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Tue Jul 23 13:23:12 2024 +0900
[Relax][PyTorch] Add support for torch.einsum (#17186)
Add torch.einsum support to Relax PyTorch Frontend.
---
python/tvm/relax/frontend/torch/fx_translator.py | 9 +++++
tests/python/relax/test_frontend_from_fx.py | 43 ++++++++++++++++++++++++
2 files changed, 52 insertions(+)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index f9a5d9c33f..e6b39c3eee 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -518,6 +518,14 @@ class TorchFXImporter:
res = bias if res is None else self.block_builder.emit(relax.op.add(res, bias))
return res
+ def _einsum(self, node: fx.node.Node) -> relax.Var:
+ import torch # type: ignore
+
+ args = self.retrieve_args(node)
+ if isinstance(args[1], (torch.Size, tuple, list)):
+ return self.block_builder.emit(relax.op.einsum(tuple(args[1]), args[0]))
+ return self.block_builder.emit(relax.op.einsum(args[1:], args[0]))
+
########## Manipulation ##########
def _cat(self, node: fx.node.Node) -> relax.Var:
@@ -1482,6 +1490,7 @@ class TorchFXImporter:
"max": self._max,
"cross_entropy": self._cross_entropy,
"scaled_dot_product_attention": self._scaled_dot_product_attention,
+ "einsum": self._einsum,
}
def update_convert_map(self, custom_convert_map: dict):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index 46c079aa99..b4ac3fa60c 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -650,6 +650,49 @@ def test_baddbmm():
)
+def test_einsum():
+ class Einsum1(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return torch.einsum("ii", x)
+
+ class Einsum2(Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return torch.einsum("i,j->ij", x, y)
+
+ @tvm.script.ir_module
+ class Expected1:
+ @R.function
+ def main(inp_0: R.Tensor((4, 4), dtype="float32")) -> R.Tensor((), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((), dtype="float32") = R.einsum((inp_0,), subscripts="ii")
+ gv: R.Tensor((), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected2:
+ @R.function
+ def main(
+ inp_0: R.Tensor((5,), dtype="float32"), inp_1: R.Tensor((4,), dtype="float32")
+ ) -> R.Tensor((5, 4), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((5, 4), dtype="float32") = R.einsum(
+ (inp_0, inp_1), subscripts="i,j->ij"
+ )
+ gv: R.Tensor((5, 4), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Einsum1(), [([4, 4], "float32")], {}, Expected1)
+ verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
+
+
def test_relu():
class ReLU0(Module):
def __init__(self):